Skip to content

Commit

Permalink
Faiss SQFP16 Range Validation and Clipping (#1562)
Browse files Browse the repository at this point in the history
* Add Range Validation for SQFP16 (#1493)

* Add Range Validation for SQFP16 Vector Data

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add index setting to clip vector data to FP16 range

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add CHANGELOG

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add an encoder parameter to clip fp16 range

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add BWC Tests

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* SQFP16 Range Validation for Faiss IVF Models (#1557)

* SQFP16 Range Validation for Faiss IVF Models

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Rebase Changes

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda authored Mar 18, 2024
1 parent 4b0078d commit d63ce27
Show file tree
Hide file tree
Showing 11 changed files with 1,007 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
* Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527)
* Added Inner Product Space type support for Lucene Engine [#1551](https://github.com/opensearch-project/k-NN/pull/1551)
* Add Range Validation for Faiss SQFP16 [#1493](https://github.com/opensearch-project/k-NN/pull/1493)
* SQFP16 Range Validation for Faiss IVF Models [#1557](https://github.com/opensearch-project/k-NN/pull/1557)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
390 changes: 390 additions & 0 deletions qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
public static final String FAISS_SQ_CLIP = "clip";

// Parameter defaults/limits
public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1;
Expand All @@ -111,6 +112,9 @@ public class KNNConstants {
public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30;
public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30;

public static final Float FP16_MAX_VALUE = 65504.0f;
public static final Float FP16_MIN_VALUE = -65504.0f;

// Lib names
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME;
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ public T getDefaultValue() {
*/
public abstract ValidationException validate(Object value);

/**
* Boolean method parameter
*/
public static class BooleanParameter extends Parameter<Boolean> {
public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> validator) {
super(name, defaultValue, validator);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
return validationException;
}

if (!validator.test((Boolean) value)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));
}
return validationException;
}
}

/**
* Integer method parameter
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,39 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

Expand Down Expand Up @@ -511,10 +529,23 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType());
parseCreateField(
context,
fieldType().getDimension(),
fieldType().getSpaceType(),
getMethodComponentContext(fieldType().getKnnMethodContext())
);
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) {
if (Objects.isNull(knnMethodContext)) {
return null;
}
return knnMethodContext.getMethodComponentContext();
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -532,7 +563,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand All @@ -551,6 +582,47 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.path().remove();
}

// Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16"
protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) {
if (Objects.isNull(methodComponentContext)) {
return false;
}

if (methodComponentContext.getParameters().size() == 0) {
return false;
}

Map<String, Object> methodComponentParams = methodComponentContext.getParameters();

// The method component parameters should have an encoder
if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) {
return false;
}

// Validate if the object is of type MethodComponentContext before casting it later
if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) {
return false;
}

MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER);

// returns true if encoder name is "sq" and type is "fp16"
return ENCODER_SQ.equals(encoderMethodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(
encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
);

}

// Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index
// using "sq" encoder of type "fp16".
protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) {
if (Objects.nonNull(methodComponentContext)) {
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
}
return false;
}

void validateIfCircuitBreakerIsNotTriggered() {
if (KNNSettings.isCircuitBreakerTriggered()) {
throw new IllegalStateException(
Expand Down Expand Up @@ -600,23 +672,53 @@ Optional<byte[]> getBytesFromContext(ParseContext context, int dimension) throws
return Optional.of(array);
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext)
throws IOException {
context.path().add(simpleName());

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'.
// If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be
// clipped to FP16 range.
boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext);
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
);
}

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);
} else {
validateFP16VectorValue(value);
}
} else {
validateFloatVectorValue(value);
}

vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);
} else {
validateFP16VectorValue(value);
}
} else {
validateFloatVectorValue(value);
}
vector.add(value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,55 @@

import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {

/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range
* or is not within the FP16 range of [-65504 to 65504].
*
* @param value float vector value
*/
public static void validateFP16VectorValue(float value) {
validateFloatVectorValue(value);

if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
);
}
}

/**
* Validate the float vector value and if it is outside FP16 range,
* then it will be clipped to FP16 range of [-65504 to 65504].
*
* @param value float vector value
* @return vector value clipped to FP16 range
*/
public static float clipVectorValueToFP16Range(float value) {
validateFloatVectorValue(value);
if (value < FP16_MIN_VALUE) return FP16_MIN_VALUE;
if (value > FP16_MAX_VALUE) return FP16_MAX_VALUE;
return value;
}

/**
* Validates and throws exception if data_type field is set in the index mapping
* using any VectorDataType (other than float, which is default) because other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -96,7 +98,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType());
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_IVF_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_PQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES;
Expand Down Expand Up @@ -90,6 +91,7 @@ class Faiss extends NativeLibrary {
FAISS_SQ_TYPE,
new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains)
)
.addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull))
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_SQ_DESCRIPTION,
Expand Down
Loading

0 comments on commit d63ce27

Please sign in to comment.