From 8e7331e95ddd39823121d7500b4667723c10e084 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:45:13 -0400 Subject: [PATCH 1/3] New bulk scorer for binary quantized vectors via optimized scalar quantization --- .../benchmark/vector/OSQScorerBenchmark.java | 204 ++++++++ .../DefaultESVectorizationProvider.java | 9 + .../vectorization/ES91OSQVectorsScorer.java | 166 +++++++ .../ESVectorizationProvider.java | 6 + .../ESVectorizationProvider.java | 5 + .../MemorySegmentES91OSQVectorsScorer.java | 450 ++++++++++++++++++ .../PanamaESVectorizationProvider.java | 17 + .../ES91OSQVectorScorerTests.java | 103 ++++ 8 files changed, 960 insertions(+) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java new file mode 100644 index 0000000000000..486919f10bcf5 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -0,0 +1,204 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 4, time = 1) +// real iterations. 
not useful to spend tons of time here, better to fork more +@Measurement(iterations = 5, time = 1) +// engage some noise reduction +@Fork(value = 1) +public class OSQScorerBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "1024" }) + int dims; + + int length; + + int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10; + int numQueries = 10; + + byte[][] binaryVectors; + byte[][] binaryQueries; + OptimizedScalarQuantizer.QuantizationResult result; + float centroidDp; + + byte[] scratch; + ES91OSQVectorsScorer scorer; + + IndexInput in; + + float[] scratchScores; + float[] corrections; + + @Setup + public void setup() throws IOException { + Random random = new Random(123); + + this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8; + + binaryVectors = new byte[numVectors][length]; + for (byte[] binaryVector : binaryVectors) { + random.nextBytes(binaryVector); + } + + Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData")); + IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT); + byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE]; + for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) { + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length); + } + random.nextBytes(correctionBytes); + out.writeBytes(correctionBytes, 0, correctionBytes.length); + } + out.close(); + in = dir.openInput("vectors", IOContext.DEFAULT); + + binaryQueries = new byte[numVectors][4 * length]; + for (byte[] binaryVector : binaryVectors) { + random.nextBytes(binaryVector); + } + result = new OptimizedScalarQuantizer.QuantizationResult( + random.nextFloat(), + random.nextFloat(), + random.nextFloat(), + Short.toUnsignedInt((short) random.nextInt()) + ); + centroidDp = random.nextFloat(); + + scratch = new byte[length]; + scorer = 
ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims); + scratchScores = new float[16]; + corrections = new float[3]; + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromArray(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i++) { + in.readBytes(scratch, 0, length); + float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch); + in.readFloats(corrections, 0, corrections.length); + int addition = Short.toUnsignedInt(in.readShort()); + float score = scorer.score( + result, + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + corrections[0], + corrections[1], + addition, + corrections[2], + qDist + ); + bh.consume(score); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i++) { + float qDist = scorer.quantizeScore(binaryQueries[j]); + in.readFloats(corrections, 0, corrections.length); + int addition = Short.toUnsignedInt(in.readShort()); + float score = scorer.score( + result, + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + corrections[0], + corrections[1], + addition, + corrections[2], + qDist + ); + bh.consume(score); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i += 16) { + scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores); + for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) { + in.readFloats(corrections, 0, corrections.length); + int addition = Short.toUnsignedInt(in.readShort()); + float score = scorer.score( + result, + 
VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + corrections[0], + corrections[1], + addition, + corrections[2], + scratchScores[k] + ); + bh.consume(score); + } + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i += 16) { + scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores); + bh.consume(scratchScores); + } + } + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java index 6c0f7ed146b86..e8ff6f83f2172 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java @@ -9,6 +9,10 @@ package org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + final class DefaultESVectorizationProvider extends ESVectorizationProvider { private final ESVectorUtilSupport vectorUtilSupport; @@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider { public ESVectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } + + @Override + public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { + return new ES91OSQVectorsScorer(input, dimension); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java new file mode 100644 index 
0000000000000..66a8e5b89855d --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Scorer for quantized vectors stored as an {@link IndexInput}. */ +public class ES91OSQVectorsScorer { + + public static final int BULK_SIZE = 16; + + protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + /** The wrapper {@link IndexInput}. */ + protected final IndexInput in; + + protected final int length; + protected final int dimensions; + + protected final float[] lowerIntervals = new float[BULK_SIZE]; + protected final float[] upperIntervals = new float[BULK_SIZE]; + protected final int[] targetComponentSums = new int[BULK_SIZE]; + protected final float[] additionalCorrections = new float[BULK_SIZE]; + + /** Sole constructor, called by sub-classes. */ + public ES91OSQVectorsScorer(IndexInput in, int dimensions) { + this.in = in; + this.dimensions = dimensions; + this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8; + } + + /** + * compute the quantize distance between the provided quantized query and the quantized vector + * that is read from the wrapped {@link IndexInput}. 
+ */ + public long quantizeScore(byte[] q) throws IOException { + assert q.length == length * 4; + final int size = length; + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int r = 0; + for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value); + } + for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value); + } + for (; r < size; r++) { + final byte value = in.readByte(); + subRet0 += Integer.bitCount((q[r] & value) & 0xFF); + subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF); + subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF); + subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + /** + * compute the quantize distance between the provided quantized query and the quantized vectors + * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is + * determined by {code count} and the results are stored in the provided {@code scores} array. 
+ */ + public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + for (int i = 0; i < count; i++) { + scores[i] = quantizeScore(q); + } + } + + /** + * Computes the score by applying the necessary corrections to the provided quantized distance. + */ + public float score( + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float lowerInterval, + float upperInterval, + int targetComponentSum, + float additionalCorrection, + float qcDist + ) { + float ax = lowerInterval; + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = upperInterval - ax; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } + + /** + * compute the distance between the provided quantized query and the quantized vectors that are + * read from the wrapped {@link IndexInput}. + * + *

The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the + * input is as follows: First the quantized vectors are read from the input, then all the lower + * intervals as floats, then all the upper intervals as floats, then all the target component sums + * as shorts, and finally all the additional corrections as floats. + + *

The results are stored in the provided scores array. + */ + public void scoreBulk( + byte[] q, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScoreBulk(q, BULK_SIZE, scores); + in.readFloats(lowerIntervals, 0, BULK_SIZE); + in.readFloats(upperIntervals, 0, BULK_SIZE); + for (int i = 0; i < BULK_SIZE; i++) { + targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); + } + in.readFloats(additionalCorrections, 0, BULK_SIZE); + for (int i = 0; i < BULK_SIZE; i++) { + scores[i] = score( + queryCorrections, + similarityFunction, + centroidDp, + lowerIntervals[i], + upperIntervals[i], + targetComponentSums[i], + additionalCorrections[i], + scores[i] + ); + } + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index e541c10e145bf..7f4e62f156a36 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -9,6 +9,9 @@ package org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; import java.util.Objects; public abstract class ESVectorizationProvider { @@ -24,6 +27,9 @@ public static ESVectorizationProvider getInstance() { public abstract ESVectorUtilSupport getVectorUtilSupport(); + /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. 
*/ + public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { return new DefaultESVectorizationProvider(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 758116300aa0f..af5df094659e5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -9,10 +9,12 @@ package org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Constants; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import java.io.IOException; import java.util.Locale; import java.util.Objects; import java.util.Optional; @@ -32,6 +34,9 @@ public static ESVectorizationProvider getInstance() { public abstract ESVectorUtilSupport getVectorUtilSupport(); + /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. 
*/ + public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { final int runtimeVersion = Runtime.version().feature(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java new file mode 100644 index 0000000000000..3b9c8ae5c5212 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -0,0 +1,450 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Panamized scorer for quantized vectors stored as an {@link IndexInput}. 
*/ +public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScorer { + + private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + + private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; + private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; + + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; + private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; + + private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; + private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; + + private final MemorySegment memorySegment; + + public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions); + this.memorySegment = memorySegment; + } + + @Override + public long quantizeScore(byte[] q) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScore256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScore128(q); + } + } + return super.quantizeScore(q); + } + + private long quantizeScore256(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i 
+= ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = 
sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + private long quantizeScore128(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = 
sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + @Override + public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + quantizeScore256Bulk(q, count, scores); + return; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + quantizeScore128Bulk(q, count, scores); + return; + } + } + super.quantizeScoreBulk(q, count, scores); + } + + private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, 
ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + } + scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + + private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var 
vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = 
sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + + @Override + public void scoreBulk( + byte[] q, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores); + return; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores); + return; + } + } + super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores); + } + + private void score128Bulk( + byte[] q, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScore128Bulk(q, BULK_SIZE, scores); + int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * 
FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + for (; i < limit; i += FLOAT_SPECIES_128.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES_128, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES_128, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES_128, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
+ if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 14L * BULK_SIZE); + } + + private void score256Bulk( + byte[] q, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScore256Bulk(q, BULK_SIZE, scores); + int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + for (; i < limit; i += FLOAT_SPECIES_256.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES_256, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES_256, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 
0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES_256, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 14L * BULK_SIZE); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index 62d25d79487ed..c409be4fb37d8 100644 --- 
a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -9,6 +9,12 @@ package org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + final class PanamaESVectorizationProvider extends ESVectorizationProvider { private final ESVectorUtilSupport vectorUtilSupport; @@ -21,4 +27,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider { public ESVectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } + + @Override + public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { + if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) { + MemorySegment ms = msai.segmentSliceOrNull(0, input.length()); + if (ms != null) { + return new MemorySegmentES91OSQVectorsScorer(input, dimension, ms); + } + } + return new ES91OSQVectorsScorer(input, dimension); + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java new file mode 100644 index 0000000000000..81172872d47d8 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +public class ES91OSQVectorScorerTests extends BaseVectorizationTests { + public void testQuantizeScore() throws Exception { + final int dimensions = random().nextInt(1, 2000); + final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8; + final int numVectors = random().nextInt(1, 100); + final byte[] vector = new byte[length]; + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + for (int i = 0; i < numVectors; i++) { + random().nextBytes(vector); + out.writeBytes(vector, 0, length); + } + } + final byte[] query = new byte[4 * length]; + random().nextBytes(query); + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { + // Work on a slice that has just the right number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. 
+ final IndexInput slice = in.slice("test", 0, (long) length * numVectors); + final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions); + final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions); + for (int i = 0; i < numVectors; i++) { + assertEquals(defaultScorer.quantizeScore(query), panamaScorer.quantizeScore(query)); + assertEquals(in.getFilePointer(), slice.getFilePointer()); + } + assertEquals((long) length * numVectors, slice.getFilePointer()); + } + } + } + + public void testScore() throws Exception { + final int dimensions = random().nextInt(1, 2000); + final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8; + final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10); + final byte[] vector = new byte[length + 14]; + int padding = random().nextInt(100); + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) { + for (int i = 0; i < padding; i++) { + out.writeByte((byte) random().nextInt()); + } + for (int i = 0; i < numVectors; i++) { + random().nextBytes(vector); + out.writeBytes(vector, 0, length + 14); + } + } + final byte[] query = new byte[4 * length]; + random().nextBytes(query); + OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult( + random().nextFloat(), + random().nextFloat(), + random().nextFloat(), + Short.toUnsignedInt((short) random().nextInt()) + ); + final float centroidDp = random().nextFloat(); + final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE]; + final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE]; + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) { + in.seek(padding); + assertEquals(in.length(), padding + (long) numVectors * 
(length + 14)); + // Work on a slice that has just the right number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. + for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) { + final IndexInput slice = in.slice( + "test", + in.getFilePointer(), + (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE + ); + final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions); + final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions); + defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1); + panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2); + assertArrayEquals(scores1, scores2, 1e-2f); + assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer()); + assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer()); + } + } + } + } + } +} From f134ad3c7a8759e3b8d622a1bfbfa77f64314a13 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:42:27 -0400 Subject: [PATCH 2/3] fixing headers --- .../internal/vectorization/ES91OSQVectorsScorer.java | 8 +++++--- .../vectorization/MemorySegmentES91OSQVectorsScorer.java | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java index 66a8e5b89855d..839e5f29a1148 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java @@ -1,8 +1,10 @@ /* * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ package org.elasticsearch.simdvec.internal.vectorization; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 3b9c8ae5c5212..0bf3b8da22d2c 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -1,8 +1,10 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ package org.elasticsearch.simdvec.internal.vectorization; From 9353495d4781f9d6c0fb00ef6940c0eab8aa18ab Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 23 Apr 2025 16:07:20 -0400 Subject: [PATCH 3/3] fixing tests --- .../ES91OSQVectorScorerTests.java | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index 81172872d47d8..53b14ae4910c0 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -17,7 +17,10 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import static org.hamcrest.Matchers.lessThan; + public class ES91OSQVectorScorerTests extends BaseVectorizationTests { + public void testQuantizeScore() throws Exception { final int dimensions = random().nextInt(1, 2000); final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8; @@ -49,28 +52,38 @@ public void testQuantizeScore() throws Exception { } public void testScore() throws Exception { - final int dimensions = random().nextInt(1, 2000); + final int maxDims = 512; + final int dimensions = random().nextInt(1, maxDims); final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8; final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10); - final byte[] vector = new byte[length + 14]; + final byte[] vector = new byte[length]; int padding = random().nextInt(100); + byte[] paddingBytes = new byte[padding]; try (Directory dir = new MMapDirectory(createTempDir())) { try (IndexOutput out = 
dir.createOutput("testScore.bin", IOContext.DEFAULT)) { - for (int i = 0; i < padding; i++) { - out.writeByte((byte) random().nextInt()); - } + random().nextBytes(paddingBytes); + out.writeBytes(paddingBytes, 0, padding); for (int i = 0; i < numVectors; i++) { random().nextBytes(vector); - out.writeBytes(vector, 0, length + 14); + out.writeBytes(vector, 0, length); + float lower = random().nextFloat(); + float upper = random().nextFloat() + lower / 2; + float additionalCorrection = random().nextFloat(); + int targetComponentSum = randomIntBetween(0, dimensions / 2); + out.writeInt(Float.floatToIntBits(lower)); + out.writeInt(Float.floatToIntBits(upper)); + out.writeShort((short) targetComponentSum); + out.writeInt(Float.floatToIntBits(additionalCorrection)); } } final byte[] query = new byte[4 * length]; random().nextBytes(query); + float lower = random().nextFloat(); OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult( + lower, + random().nextFloat() + lower / 2, random().nextFloat(), - random().nextFloat(), - random().nextFloat(), - Short.toUnsignedInt((short) random().nextInt()) + randomIntBetween(0, dimensions * 2) ); final float centroidDp = random().nextFloat(); final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE]; @@ -92,7 +105,17 @@ public void testScore() throws Exception { final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions); defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1); panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2); - assertArrayEquals(scores1, scores2, 1e-2f); + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + if (scores1[j] > (maxDims * Short.MAX_VALUE)) { + int diff = (int) (scores1[j] - scores2[j]); + assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(65)); + } else if (scores1[j] > (maxDims * Byte.MAX_VALUE)) { + 
int diff = (int) (scores1[j] - scores2[j]); + assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(9)); + } else { + assertEquals(scores1[j], scores2[j], 1e-2f); + } + } assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer()); assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer()); }