diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index fc1fa59824f99..503726f6bc871 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -61,6 +61,7 @@ private enum VectorSourceOptions { .collect(Collectors.toSet()); public static final float DELTA = 1e-7F; + public static final float BFLOAT16_DELTA = 1e-2F; private final ElementType elementType; private final DenseVectorFieldMapper.VectorSimilarity similarity; @@ -70,7 +71,7 @@ private enum VectorSourceOptions { @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); - for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT)) { + for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT, ElementType.BFLOAT16)) { // Test all similarities for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) { if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) { @@ -137,8 +138,10 @@ public void testRetrieveTopNDenseVectorFieldData() { } else { assertNotNull(actualVector); assertEquals(expectedVector.size(), actualVector.size()); + + float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA; for (int i = 0; i < expectedVector.size(); i++) { - assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), delta); } } }); @@ -167,12 +170,14 @@ public void testRetrieveDenseVectorFieldData() { } else { assertNotNull(actualVector); assertEquals(expectedVector.size(), actualVector.size()); + + float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA; for (int i = 0; i < actualVector.size(); i++) { assertEquals( "Actual: " + actualVector + "; expected: " + expectedVector, expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), - DELTA + delta ); } } @@ -253,12 +258,13 @@ public void setup() throws IOException { } else { for (int j = 0; j < numDims; j++) { switch (elementType) { - case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); + case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true)); case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127)); default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); } } - if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { + if ((elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16) + && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); vector.replaceAll(number -> number.floatValue() / magnitude); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 111f7f0c8ee72..96ef347d70315 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -57,6 +57,7 @@ public static Iterable parameters() throws Exception { List params = new ArrayList<>(); for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); + params.add(new Object[] { DenseVectorFieldMapper.ElementType.BFLOAT16, indexType }); } for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); @@ -264,7 +265,7 @@ public void setup() throws IOException { List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { switch (elementType) { - case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); + case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true)); case BYTE, BIT -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f)); default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); }