Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ private enum VectorSourceOptions {
.collect(Collectors.toSet());

public static final float DELTA = 1e-7F;
public static final float BFLOAT16_DELTA = 1e-2F;

private final ElementType elementType;
private final DenseVectorFieldMapper.VectorSimilarity similarity;
Expand All @@ -70,7 +71,7 @@ private enum VectorSourceOptions {
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT)) {
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT, ElementType.BFLOAT16)) {
// Test all similarities
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) {
Expand Down Expand Up @@ -137,8 +138,10 @@ public void testRetrieveTopNDenseVectorFieldData() {
} else {
assertNotNull(actualVector);
assertEquals(expectedVector.size(), actualVector.size());

float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
for (int i = 0; i < expectedVector.size(); i++) {
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA);
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), delta);
}
}
});
Expand Down Expand Up @@ -167,12 +170,14 @@ public void testRetrieveDenseVectorFieldData() {
} else {
assertNotNull(actualVector);
assertEquals(expectedVector.size(), actualVector.size());

float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
for (int i = 0; i < actualVector.size(); i++) {
assertEquals(
"Actual: " + actualVector + "; expected: " + expectedVector,
expectedVector.get(i).floatValue(),
actualVector.get(i).floatValue(),
DELTA
delta
);
}
}
Expand Down Expand Up @@ -253,12 +258,13 @@ public void setup() throws IOException {
} else {
for (int j = 0; j < numDims; j++) {
switch (elementType) {
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127));
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
}
}
if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
if ((elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16)
&& (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
// Normalize the vector
float magnitude = DenseVector.getMagnitude(vector);
vector.replaceAll(number -> number.floatValue() / magnitude);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();
for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) {
params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType });
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BFLOAT16, indexType });
}
for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) {
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType });
Expand Down Expand Up @@ -264,7 +265,7 @@ public void setup() throws IOException {
List<Number> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
switch (elementType) {
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
case BYTE, BIT -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f));
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
}
Expand Down