@@ -0,0 +1,204 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */
package org.elasticsearch.benchmark.vector;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.IOException;
import java.nio.file.Files;
import java.util.Random;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 4, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 5, time = 1)
// engage some noise reduction
@Fork(value = 1)
public class OSQScorerBenchmark {

    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    @Param({ "1024" })
    int dims;

    int length;

    int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
    int numQueries = 10;

    byte[][] binaryVectors;
    byte[][] binaryQueries;
    OptimizedScalarQuantizer.QuantizationResult result;
    float centroidDp;

    byte[] scratch;
    ES91OSQVectorsScorer scorer;

    IndexInput in;

    float[] scratchScores;
    float[] corrections;

    @Setup
    public void setup() throws IOException {
        Random random = new Random(123);

        this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
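        // binary-quantized vectors pack one bit per dimension, padded to a multiple of 64 bits, hence `length` bytes each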

        binaryVectors = new byte[numVectors][length];
        for (byte[] binaryVector : binaryVectors) {
            random.nextBytes(binaryVector);
        }

        Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
        IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT);
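        // Each block of BULK_SIZE vectors is followed by its corrections: 3 floats + 1 unsigned short = 14 bytes per vector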
        byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
        for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
                out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
            }
            random.nextBytes(correctionBytes);
            out.writeBytes(correctionBytes, 0, correctionBytes.length);
        }
        out.close();
        in = dir.openInput("vectors", IOContext.DEFAULT);

        binaryQueries = new byte[numVectors][4 * length];
        for (byte[] binaryQuery : binaryQueries) {
            random.nextBytes(binaryQuery);
        }
        result = new OptimizedScalarQuantizer.QuantizationResult(
            random.nextFloat(),
            random.nextFloat(),
            random.nextFloat(),
            Short.toUnsignedInt((short) random.nextInt())
        );
        centroidDp = random.nextFloat();

        scratch = new byte[length];
        scorer = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims);
        scratchScores = new float[ES91OSQVectorsScorer.BULK_SIZE];
        corrections = new float[3];
    }

    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromArray(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            for (int i = 0; i < numVectors; i++) {
                in.readBytes(scratch, 0, length);
                float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch);
                in.readFloats(corrections, 0, corrections.length);
                int addition = Short.toUnsignedInt(in.readShort());
                float score = scorer.score(
                    result,
                    VectorSimilarityFunction.EUCLIDEAN,
                    centroidDp,
                    corrections[0],
                    corrections[1],
                    addition,
                    corrections[2],
                    qDist
                );
                bh.consume(score);
            }
        }
    }

    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            for (int i = 0; i < numVectors; i++) {
                float qDist = scorer.quantizeScore(binaryQueries[j]);
                in.readFloats(corrections, 0, corrections.length);
                int addition = Short.toUnsignedInt(in.readShort());
                float score = scorer.score(
                    result,
                    VectorSimilarityFunction.EUCLIDEAN,
                    centroidDp,
                    corrections[0],
                    corrections[1],
                    addition,
                    corrections[2],
                    qDist
                );
                bh.consume(score);
            }
        }
    }

    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
                scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
                for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
                    in.readFloats(corrections, 0, corrections.length);
                    int addition = Short.toUnsignedInt(in.readShort());
                    float score = scorer.score(
                        result,
                        VectorSimilarityFunction.EUCLIDEAN,
                        centroidDp,
                        corrections[0],
                        corrections[1],
                        addition,
                        corrections[2],
                        scratchScores[k]
                    );
                    bh.consume(score);
                }
            }
        }
    }

    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
                scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
                bh.consume(scratchScores);
            }
        }
    }
}
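For reference, a JMH benchmark like this one can also be launched programmatically through the standard JMH runner API. The sketch below is hypothetical and not part of this change: the class name OSQScorerBenchmarkRunner is invented for illustration, and in the Elasticsearch repository the benchmarks module's own harness is the normal entry point.

// Hypothetical JMH launcher for the benchmark above; a sketch, not part of this change.
// Uses only the standard JMH runner API. The forked JVMs pick up the
// --add-modules=jdk.incubator.vector argument from the @Fork annotations above.
package org.elasticsearch.benchmark.vector;

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class OSQScorerBenchmarkRunner {
    public static void main(String[] args) throws RunnerException {
        Options opts = new OptionsBuilder()
            .include(OSQScorerBenchmark.class.getSimpleName()) // regex over benchmark names
            .build();
        new Runner(opts).run();
    }
}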
@@ -9,6 +9,10 @@

package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.store.IndexInput;

import java.io.IOException;

final class DefaultESVectorizationProvider extends ESVectorizationProvider {
    private final ESVectorUtilSupport vectorUtilSupport;

@@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
    public ESVectorUtilSupport getVectorUtilSupport() {
        return vectorUtilSupport;
    }

    @Override
    public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
        return new ES91OSQVectorsScorer(input, dimension);
    }
}
@@ -0,0 +1,168 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */
package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

import java.io.IOException;

import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;

/** Scorer for quantized vectors stored as an {@link IndexInput}. */
public class ES91OSQVectorsScorer {

    public static final int BULK_SIZE = 16;

    protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);

    /** The wrapped {@link IndexInput}. */
    protected final IndexInput in;

    protected final int length;
    protected final int dimensions;

    protected final float[] lowerIntervals = new float[BULK_SIZE];
    protected final float[] upperIntervals = new float[BULK_SIZE];
    protected final int[] targetComponentSums = new int[BULK_SIZE];
    protected final float[] additionalCorrections = new float[BULK_SIZE];

    /** Sole constructor, called by sub-classes. */
    public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
        this.in = in;
        this.dimensions = dimensions;
        this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
    }
    /**
     * Computes the quantized distance between the provided quantized query and the quantized vector
     * that is read from the wrapped {@link IndexInput}.
     */
    public long quantizeScore(byte[] q) throws IOException {
        assert q.length == length * 4;
        final int size = length;
        long subRet0 = 0;
        long subRet1 = 0;
        long subRet2 = 0;
        long subRet3 = 0;
        int r = 0;
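        // The int4 query is stored as four bit-planes of `size` bytes each; every plane is AND-ed
        // against the stored bit vector and popcounted, consuming longs, then ints, then bytes.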
        for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
            final long value = in.readLong();
            subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
            subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
            subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
            subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
        }
        for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
            final int value = in.readInt();
            subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
            subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
            subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
            subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
        }
        for (; r < size; r++) {
            final byte value = in.readByte();
            subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
            subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
            subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
            subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
        }
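        // Weight each bit-plane by its significance (1, 2, 4, 8) to reconstruct the int4 x binary dot product.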
        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
    }

    /**
     * Computes the quantized distance between the provided quantized query and the quantized vectors
     * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
     * determined by {@code count} and the results are stored in the provided {@code scores} array.
     * This default implementation simply loops over {@link #quantizeScore(byte[])}; subclasses may
     * provide a vectorized implementation.
     */
    public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
        for (int i = 0; i < count; i++) {
            scores[i] = quantizeScore(q);
        }
    }

    /**
     * Computes the score by applying the necessary corrections to the provided quantized distance.
     */
    public float score(
        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float lowerInterval,
        float upperInterval,
        int targetComponentSum,
        float additionalCorrection,
        float qcDist
    ) {
        float ax = lowerInterval;
        // The indexed vectors are single-bit quantized, so no scaling of the interval width is necessary here
        float lx = upperInterval - ax;
        float ay = queryCorrections.lowerInterval();
        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
        float y1 = queryCorrections.quantizedComponentSum();
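        // With x ~ ax + lx * qx (doc) and y ~ ay + ly * qy (query), the dot product <x, y> expands into the
        // four terms below: ax*ay*dims + ay*lx*sum(qx) + ax*ly*sum(qy) + lx*ly*<qx, qy>.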
        float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
        // For euclidean, we need to invert the score and apply the additional correction, which is
        // assumed to be the squared l2norm of the centroid centered vectors.
        if (similarityFunction == EUCLIDEAN) {
            score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
            return Math.max(1 / (1f + score), 0);
        } else {
            // For cosine and max inner product, we need to apply the additional correction, which is
            // assumed to be the non-centered dot-product between the vector and the centroid
            score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
                return VectorUtil.scaleMaxInnerProductScore(score);
            }
            return Math.max((1f + score) / 2f, 0);
        }
    }

    /**
     * Computes the distance between the provided quantized query and the quantized vectors that are
     * read from the wrapped {@link IndexInput}.
     *
     * <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
     * input is as follows: first the quantized vectors are read from the input, then all the lower
     * intervals as floats, then all the upper intervals as floats, then all the target component sums
     * as shorts, and finally all the additional corrections as floats.
     *
     * <p>The results are stored in the provided scores array.
     */
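    // On-disk layout consumed per call (spelling out the javadoc above in bytes):
    //   BULK_SIZE * length bytes - packed bit vectors
    //   BULK_SIZE floats         - lower intervals
    //   BULK_SIZE floats         - upper intervals
    //   BULK_SIZE shorts         - target component sums (unsigned)
    //   BULK_SIZE floats         - additional corrections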
    public void scoreBulk(
        byte[] q,
        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        quantizeScoreBulk(q, BULK_SIZE, scores);
        in.readFloats(lowerIntervals, 0, BULK_SIZE);
        in.readFloats(upperIntervals, 0, BULK_SIZE);
        for (int i = 0; i < BULK_SIZE; i++) {
            targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
        }
        in.readFloats(additionalCorrections, 0, BULK_SIZE);
        for (int i = 0; i < BULK_SIZE; i++) {
            scores[i] = score(
                queryCorrections,
                similarityFunction,
                centroidDp,
                lowerIntervals[i],
                upperIntervals[i],
                targetComponentSums[i],
                additionalCorrections[i],
                scores[i]
            );
        }
    }
}