Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precompute vector length on indexing #45390

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/dense-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ PUT my_index/_doc/2

Internally, each document's dense vector is encoded as a binary
doc value. Its size in bytes is equal to
`4 * dims`, where `dims`—the number of the vector's dimensions.
`4 * dims + 4`, where `dims`—the number of the vector's dimensions.
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/sparse-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ PUT my_index/_doc/2

Internally, each document's sparse vector is encoded as a binary
doc value. Its size in bytes is equal to
`6 * NUMBER_OF_DIMENSIONS`, where `NUMBER_OF_DIMENSIONS` -
`6 * NUMBER_OF_DIMENSIONS + 4`, where `NUMBER_OF_DIMENSIONS` -
number of the vector's dimensions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.script.ExplainableScoreScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.Version;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -52,22 +53,24 @@ public float score() {

private final int shardId;
private final String indexName;

private final Version indexVersion;

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = null;
this.shardId = -1;
this.indexVersion = null;
}

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = indexName;
this.shardId = shardId;
this.indexVersion = indexVersion;
}

@Override
Expand All @@ -77,6 +80,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
leafScript.setScorer(scorer);
leafScript._setIndexName(indexName);
leafScript._setShard(shardId);
leafScript._setIndexVersion(indexVersion);
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
return new ScriptScoreFunction(script, searchScript,
context.index().getName(), context.getShardId(), context.indexVersionCreated());
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
}
Expand Down
22 changes: 22 additions & 0 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.Version;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -56,6 +57,7 @@ public abstract class ScoreScript {
private int docId;
private int shardId = -1;
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
Expand Down Expand Up @@ -155,6 +157,19 @@ public String _getIndex() {
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
* It is only used within predefined painless functions.
* @return index version or throws an exception if the index version is not set up for this script instance
*/
public Version _getIndexVersion() {
if (indexVersion != null) {
return indexVersion;
} else {
throw new IllegalArgumentException("index version can not be looked up!");
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
Expand All @@ -169,6 +184,13 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
public void _setIndexVersion(Version indexVersion) {
this.indexVersion = indexVersion;
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ setup:
index: test-index
id: 2
body:
my_dense_vector: [10.9, 10.9, 10.9]
my_dense_vector: [10.5, 10.9, 10.4]

- do:
indices.refresh: {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser.Token;
Expand Down Expand Up @@ -180,8 +181,9 @@ public void parse(ParseContext context) throws IOException {

// encode array of floats as array of integers and store into buf
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
byte[] buf = new byte[dims * INT_BYTES];
byte[] buf = indexCreatedVersion.onOrAfter(Version.V_7_4_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];
int offset = 0;
double dotProduct = 0f;
int dim = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
Expand All @@ -195,13 +197,23 @@ public void parse(ParseContext context) throws IOException {
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));
if (indexCreatedVersion.onOrAfter(Version.V_7_4_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't not support indexing multiple values for the same field in the same document");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ public void parse(ParseContext context) throws IOException {
}
}

BytesRef br = VectorEncoderDecoder.encodeSparseVector(dims, values, dimCount);
BytesRef br = VectorEncoderDecoder.encodeSparseVector(indexCreatedVersion, dims, values, dimCount);
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), br);
context.doc().addWithKey(fieldType().name(), field);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
Expand All @@ -19,76 +20,90 @@ private VectorEncoderDecoder() { }

/**
* Encodes a sparse array represented by values, dims and dimCount into a bytes array - BytesRef
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension
* @param values - values of the sparse array
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension, length of vector
* @param indexVersion - index version
* @param dims - dims of the sparse array
* @param values - values of the sparse array
* @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays,
* and may be over-allocated
* @return BytesRef
*/
public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) {
public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, float[] values, int dimCount) {
Copy link
Contributor

@jtibshirani jtibshirani Aug 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could also cut down on code duplication here. If it helps refactor the two code paths into one, I think it's okay to calculate dotProduct for indexes older than 7.4.0, even though we'll end up throwing the value away. It shouldn't be too much overhead, and we don't expect too many users to be locked into the 7.3 version.

Copy link
Contributor Author

@mayya-sharipova mayya-sharipova Aug 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jtibshirani. Makes sense. Addressed with baddf608

// 1. Sort dims and values
sortSparseDimsValues(dims, values, dimCount);
byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)];

// 2. Encode dimensions
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it
byte[] buf = indexVersion.onOrAfter(Version.V_7_4_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
new byte[dimCount * (INT_BYTES + SHORT_BYTES)];
int offset = 0;
for (int dim = 0; dim < dimCount; dim++) {
buf[offset] = (byte) (dims[dim] >> 8);
buf[offset+1] = (byte) dims[dim];
offset += SHORT_BYTES;
buf[offset++] = (byte) (dims[dim] >> 8);
buf[offset++] = (byte) dims[dim];
}

// 3. Encode values
double dotProduct = 0.0f;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = Float.floatToIntBits(values[dim]);
buf[offset] = (byte) (intValue >> 24);
buf[offset+1] = (byte) (intValue >> 16);
buf[offset+2] = (byte) (intValue >> 8);
buf[offset+3] = (byte) intValue;
offset += INT_BYTES;
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;
dotProduct += values[dim] * values[dim];
}

// 4. Encode vector magnitude at the end
if (indexVersion.onOrAfter(Version.V_7_4_0)) {
float vectorMagnitude = (float) Math.sqrt(dotProduct);
int vectorMagnitudeIntValue = Float.floatToIntBits(vectorMagnitude);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 24);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 16);
buf[offset++] = (byte) (vectorMagnitudeIntValue >> 8);
buf[offset++] = (byte) vectorMagnitudeIntValue;
}

return new BytesRef(buf);
}

/**
* Decodes the first part of BytesRef into sparse vector dimensions
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static int[] decodeSparseVectorDims(BytesRef vectorBR) {
public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
int[] dims = new int[dimCount];
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset;
int[] dims = new int[dimCount];
for (int dim = 0; dim < dimCount; dim++) {
dims[dim] = ((vectorBR.bytes[offset] & 0xFF) << 8) | (vectorBR.bytes[offset+1] & 0xFF);
offset += SHORT_BYTES;
dims[dim] = ((vectorBR.bytes[offset++] & 0xFF) << 8) | (vectorBR.bytes[offset++] & 0xFF);
}
return dims;
}

/**
* Decodes the second part of the BytesRef into sparse vector values
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static float[] decodeSparseVector(BytesRef vectorBR) {
public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) :
vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount;
float[] vector = new float[dimCount];
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset] & 0xFF) << 24) |
((vectorBR.bytes[offset+1] & 0xFF) << 16) |
((vectorBR.bytes[offset+2] & 0xFF) << 8) |
(vectorBR.bytes[offset+3] & 0xFF);
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
offset = offset + INT_BYTES;
}
return vector;
}
Expand Down Expand Up @@ -152,15 +167,16 @@ public void swap(int i, int j) {

/**
* Decodes a BytesRef into an array of floats
* @param indexVersion - index Version
* @param vectorBR - dense vector encoded in BytesRef
*/
public static float[] decodeDenseVector(BytesRef vectorBR) {
public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / INT_BYTES;
float[] vector = new float[dimCount];
int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES;
int offset = vectorBR.offset;
float[] vector = new float[dimCount];
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
Expand All @@ -170,4 +186,32 @@ public static float[] decodeDenseVector(BytesRef vectorBR) {
}
return vector;
}

/**
* Calculates vector magnitude either by
* decoding last 4 bytes of BytesRef into a vector magnitude or calculating it
* @param indexVersion - index Version
* @param vectorBR - vector encoded in BytesRef
* @param vector - float vector
*/
public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR, float[] vector) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude
int offset = vectorBR.offset + vectorBR.length - 4;
int vectorMagnitudeIntValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
float vectorMagnitude = Float.intBitsToFloat(vectorMagnitudeIntValue);
return vectorMagnitude;
} else { // calculate vector magnitude
double dotProduct = 0f;
for (int dim = 0; dim < vector.length; dim++) {
dotProduct += (double) vector[dim] * vector[dim];
}
return (float) Math.sqrt(dotProduct);
}
}
}
Loading