Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Jun 19, 2024
1 parent 4dc9751 commit 4a4a189
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import java.io.IOException;

import static org.apache.lucene.util.BitUtil.VH_NATIVE_LONG;
import static org.elasticsearch.script.VectorScoreScriptUtils.andBitCount;

class ES815BitFlatVectorsFormat extends FlatVectorsFormat {

Expand Down Expand Up @@ -96,23 +96,6 @@ public RandomVectorScorer getRandomVectorScorer(
}
}

static int andBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
} else {
int distance = 0;
int i = 0;
for (int upperBound = a.length & -8; i < upperBound; i += 8) {
distance += Long.bitCount((long) VH_NATIVE_LONG.get(a, i) & (long) VH_NATIVE_LONG.get(b, i));
}
while (i < a.length) {
distance += Integer.bitCount((a[i] & b[i]) & 255);
++i;
}
return distance;
}
}

static float hammingScore(byte[] a, byte[] b) {
return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
return switch (elementType) {
case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
case BIT -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims / 8);
};
} else {
BinaryDocValues values = DocValues.getBinary(reader, field);
return switch (elementType) {
case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
case BIT -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims / 8);
};
}
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,28 @@
import java.util.HexFormat;
import java.util.List;

import static org.apache.lucene.util.BitUtil.VH_NATIVE_LONG;

public class VectorScoreScriptUtils {

public static final NodeFeature HAMMING_DISTANCE_FUNCTION = new NodeFeature("script.hamming");

public static int andBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
int distance = 0;
int i = 0;
for (int upperBound = a.length & -8; i < upperBound; i += 8) {
distance += Long.bitCount((long) VH_NATIVE_LONG.get(a, i) & (long) VH_NATIVE_LONG.get(b, i));
}
while (i < a.length) {
distance += Integer.bitCount((a[i] & b[i]) & 255);
++i;
}
return distance;
}

public static class DenseVectorFunction {
protected final ScoreScript scoreScript;
protected final DenseVectorDocValuesField field;
Expand Down Expand Up @@ -148,6 +166,32 @@ public double l1norm() {
}
}

public static class BitLNNorm extends ByteDenseVectorFunction implements L2NormInterface, L1NormInterface, HammingDistanceInterface {

public BitLNNorm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
}

public BitLNNorm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
super(scoreScript, field, queryVector);
}

public double l2norm() {
setNextVector();
return field.get().hamming(queryVector);
}

public double l1norm() {
setNextVector();
return field.get().hamming(queryVector);
}

public int hamming() {
setNextVector();
return field.get().hamming(queryVector);
}
}

public static class FloatL1Norm extends FloatDenseVectorFunction implements L1NormInterface {

public FloatL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
Expand Down Expand Up @@ -183,6 +227,15 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
case BIT -> {
if (queryVector instanceof List) {
yield new BitLNNorm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
yield new BitLNNorm(scoreScript, field, parsedQueryVector);
}
throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
}
};
}

Expand Down Expand Up @@ -219,8 +272,8 @@ public static final class Hamming {
@SuppressWarnings("unchecked")
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
if (queryVector instanceof List) {
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
Expand Down Expand Up @@ -293,6 +346,15 @@ public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
case BIT -> {
if (queryVector instanceof List) {
yield new BitLNNorm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
yield new BitLNNorm(scoreScript, field, parsedQueryVector);
}
throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
}
};
}

Expand Down Expand Up @@ -322,6 +384,27 @@ public double dotProduct() {
}
}

public static class BitDotProduct extends ByteDenseVectorFunction implements DotProductInterface, CosineSimilarityInterface {

public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
}

public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
super(scoreScript, field, queryVector);
}

public double dotProduct() {
setNextVector();
return field.get().andBitCount(queryVector);
}

public double cosineSimilarity() {
setNextVector();
return field.get().andBitCount(queryVector);
}
}

public static class FloatDotProduct extends FloatDenseVectorFunction implements DotProductInterface {

public FloatDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
Expand Down Expand Up @@ -357,6 +440,15 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName)
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
case BIT -> {
if (queryVector instanceof List) {
yield new BitDotProduct(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
yield new BitDotProduct(scoreScript, field, parsedQueryVector);
}
throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
}
};
}

Expand Down Expand Up @@ -421,6 +513,15 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
case BIT -> {
if (queryVector instanceof List) {
yield new BitDotProduct(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
yield new BitDotProduct(scoreScript, field, parsedQueryVector);
}
throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ public double l1Norm(List<Number> queryVector) {
return l1norm;
}

@Override
public int andBitCount(byte[] queryVector) {
throw new UnsupportedOperationException("bitAnd is not supported for float vectors");
}

@Override
public int andBitCount(List<Number> queryVector) {
throw new UnsupportedOperationException("bitAnd is not supported for float vectors");
}

@Override
public int hamming(byte[] queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.script.VectorScoreScriptUtils;

import java.nio.ByteBuffer;
import java.util.List;
Expand Down Expand Up @@ -100,6 +101,20 @@ public double l1Norm(List<Number> queryVector) {
return result;
}

@Override
public int andBitCount(byte[] queryVector) {
return VectorScoreScriptUtils.andBitCount(queryVector, vectorValue);
}

@Override
public int andBitCount(List<Number> queryVector) {
int distance = 0;
for (int i = 0; i < queryVector.size(); i++) {
distance += Integer.bitCount((queryVector.get(i).intValue() & vectorValue[i]) & 0xFF);
}
return distance;
}

@Override
public int hamming(byte[] queryVector) {
return VectorUtil.xorBitCount(queryVector, vectorValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.script.VectorScoreScriptUtils;

import java.util.List;

Expand Down Expand Up @@ -101,6 +102,20 @@ public double l1Norm(List<Number> queryVector) {
return result;
}

@Override
public int andBitCount(byte[] queryVector) {
return VectorScoreScriptUtils.andBitCount(queryVector, docVector);
}

@Override
public int andBitCount(List<Number> queryVector) {
int distance = 0;
for (int i = 0; i < queryVector.size(); i++) {
distance += Integer.bitCount((queryVector.get(i).intValue() & docVector[i]) & 0xFF);
}
return distance;
}

@Override
public int hamming(byte[] queryVector) {
return VectorUtil.xorBitCount(queryVector, docVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ default double l1Norm(Object queryVector) {
throw new IllegalArgumentException(badQueryVectorType(queryVector));
}

int andBitCount(byte[] queryVector);

int andBitCount(List<Number> queryVector);

@SuppressWarnings("unchecked")
default int andBitCount(Object queryVector) {
if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
return andBitCount((List<Number>) list);
}
if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
return andBitCount(bytes);
}

throw new IllegalArgumentException(badQueryVectorType(queryVector));
}

int hamming(byte[] queryVector);

int hamming(List<Number> queryVector);
Expand Down Expand Up @@ -248,6 +266,16 @@ public double l1Norm(List<Number> queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}

@Override
public int andBitCount(byte[] queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}

@Override
public int andBitCount(List<Number> queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}

@Override
public int hamming(byte[] queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ public double l1Norm(List<Number> queryVector) {
return result;
}

@Override
public int andBitCount(byte[] queryVector) {
throw new UnsupportedOperationException("bitAnd is not supported for float vectors");
}

@Override
public int andBitCount(List<Number> queryVector) {
throw new UnsupportedOperationException("bitAnd is not supported for float vectors");
}

@Override
public int hamming(byte[] queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) {
}
yield floats;
}
case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
// The field should always be resolved to the concrete field
Query knnVectorQueryBuilt = switch (elementType()) {
case BYTE -> new ESKnnByteVectorQuery(
case BYTE, BIT -> new ESKnnByteVectorQuery(
VECTOR_FIELD,
queryBuilder.queryVector().asByteVector(),
queryBuilder.numCands(),
Expand Down

0 comments on commit 4a4a189

Please sign in to comment.