diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 3ef32c993217f..1d1d29b6c232a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -186,13 +186,13 @@ public void testMultiKnnClauses() throws IOException { createIndex("index", indexSettings, builder); for (int doc = 0; doc < 10; doc++) { - client().prepareIndex("index").setSource("vector", randomVector(), "text", "hello world", "number", 1).get(); - client().prepareIndex("index").setSource("vector_2", randomVector(), "text", "hello world", "number", 2).get(); + client().prepareIndex("index").setSource("vector", randomVector(1.0f, 2.0f), "text", "hello world", "number", 1).get(); + client().prepareIndex("index").setSource("vector_2", randomVector(20f, 21f), "text", "hello world", "number", 2).get(); client().prepareIndex("index").setSource("text", "goodnight world", "number", 3).get(); } client().admin().indices().prepareRefresh("index").get(); - float[] queryVector = randomVector(); + float[] queryVector = randomVector(20f, 21f); KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f); KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50).boost(10.0f); SearchResponse response = client().prepareSearch("index") @@ -213,7 +213,7 @@ public void testMultiKnnClauses() throws IOException { assertThat(agg.getAvg(), equalTo(2.25)); assertThat(agg.getSum(), equalTo(45.0)); - // Because of the boost, vector_2 results should appear first + // Because of the boost & vector distributions, vector_2 results should appear first assertNotNull(response.getHits().getAt(0).field("vector_2")); } @@ -372,4 +372,12 @@ private float[] randomVector() { } return vector; } + + private float[] randomVector(float dimLower, float dimUpper) { + float[] vector = new float[VECTOR_DIMENSION]; + for (int i = 0; i < vector.length; i++) { + vector[i] = (float) randomDoubleBetween(dimLower, dimUpper, true); + } + return vector; + } }