From 5fda71ba1c14a5a6966ac6a2986ce309e24fc140 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 May 2024 09:51:45 -0400 Subject: [PATCH] improvements from DK review --- .../xpack/inference/qa/mixed/Clusters.java | 1 - .../qa/mixed/CohereServiceMixedIT.java | 56 +++++-------------- .../qa/mixed/HuggingFaceServiceMixedIT.java | 16 +++--- .../qa/mixed/OpenAIServiceMixedIT.java | 16 +++--- .../application/OpenAiServiceUpgradeIT.java | 1 - 5 files changed, 31 insertions(+), 59 deletions(-) diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/Clusters.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/Clusters.java index 870d64d7603e5..d7c0a73c9de4e 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/Clusters.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/Clusters.java @@ -20,7 +20,6 @@ public static ElasticsearchCluster mixedVersionCluster() { .withNode(node -> node.version(Version.CURRENT)) .setting("xpack.security.enabled", "false") .setting("xpack.license.self_generated.type", "trial") - .setting("cluster.routing.rebalance.enable", "none") // disable relocation until we have retry in ESQL .build(); } } diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java index c157c8135ad7c..5412339586b51 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java @@ -57,17 +57,17 @@ public void testCohereEmbeddings() throws IOException { var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED)); assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported); - final String oldClusterIdInt8 = "old-cluster-embeddings-int8"; - final String oldClusterIdFloat = "old-cluster-embeddings-float"; + final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8"; + final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float"; // queue a response as PUT will call the service cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); - put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); // float model cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); - put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints"); + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints"); assertEquals("cohere", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); @@ -75,41 +75,15 @@ public void testCohereEmbeddings() throws IOException { // An upgraded node will report the embedding type as byte, an old node int8 assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); - configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdFloat).get("endpoints"); + configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "float")); - assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); - assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE); + assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT); - { - final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8"; - - cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); - put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); - - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdInt8).get("endpoints"); - serviceSettings = (Map) configs.get(0).get("service_settings"); - assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte - - assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8); - delete(upgradedClusterIdInt8); - } - { - final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float"; - cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); - put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); - - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdFloat).get("endpoints"); - serviceSettings = (Map) configs.get(0).get("service_settings"); - assertThat(serviceSettings, hasEntry("embedding_type", "float")); - - assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT); - delete(upgradedClusterIdFloat); - } - - delete(oldClusterIdFloat); - delete(oldClusterIdInt8); + delete(inferenceIdFloat); + delete(inferenceIdInt8); } @@ -132,12 +106,12 @@ public void testRerank() throws IOException { var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED)); assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported); - final String oldClusterId = "old-cluster-rerank"; + final String inferenceId = "mixed-cluster-rerank"; - put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK); - assertRerank(oldClusterId); + put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK); + assertRerank(inferenceId); - var configs = (List>) get(TaskType.RERANK, oldClusterId).get("endpoints"); + var configs = (List>) get(TaskType.RERANK, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); assertEquals("cohere", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); @@ -145,7 +119,7 @@ public void testRerank() throws IOException { var taskSettings = (Map) configs.get(0).get("task_settings"); assertThat(taskSettings, hasEntry("top_n", 3)); - assertRerank(oldClusterId); + assertRerank(inferenceId); } diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java index e2de60d7393dd..9eed1b8a7fcd3 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java @@ -50,14 +50,14 @@ public void testHFEmbeddings() throws IOException { var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED)); assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported); - final String oldClusterId = "old-cluster-embeddings"; + final String inferenceId = "mixed-cluster-embeddings"; embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(oldClusterId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); assertEquals("hugging_face", configs.get(0).get("service")); - assertEmbeddingInference(oldClusterId); + assertEmbeddingInference(inferenceId); } void assertEmbeddingInference(String inferenceId) throws IOException { @@ -71,15 +71,15 @@ public void testElser() throws IOException { var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED)); assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported); - final String oldClusterId = "old-cluster-elser"; + final String inferenceId = "mixed-cluster-elser"; final String upgradedClusterId = "upgraded-cluster-elser"; - put(oldClusterId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING); + put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING); - var configs = (List>) get(TaskType.SPARSE_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); assertEquals("hugging_face", configs.get(0).get("service")); - assertElser(oldClusterId); + assertElser(inferenceId); } private void assertElser(String inferenceId) throws IOException { diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java index 013a2bf0d4784..edf0b97f40c93 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java @@ -54,14 +54,14 @@ public void testOpenAiEmbeddings() throws IOException { var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED)); assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported); - final String oldClusterId = "old-cluster-embeddings"; + final String inferenceId = "mixed-cluster-embeddings"; String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig(); // queue a response as PUT will call the service openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(oldClusterId, inferenceConfig, TaskType.TEXT_EMBEDDING); + put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); assertEquals("openai", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); @@ -69,7 +69,7 @@ public void testOpenAiEmbeddings() throws IOException { var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id"); assertTrue("model_id not found in config: " + configs.toString(), modelIdFound); - assertEmbeddingInference(oldClusterId); + assertEmbeddingInference(inferenceId); } void assertEmbeddingInference(String inferenceId) throws IOException { @@ -83,12 +83,12 @@ public void testOpenAiCompletions() throws IOException { var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED)); assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported); - final String oldClusterId = "old-cluster-completions"; + final String inferenceId = "mixed-cluster-completions"; final String upgradedClusterId = "upgraded-cluster-completions"; - put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); + put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); - var configsMap = get(TaskType.COMPLETION, oldClusterId); + var configsMap = get(TaskType.COMPLETION, inferenceId); logger.warn("Configs: {}", configsMap); var configs = (List>) configsMap.get("endpoints"); assertThat(configs, hasSize(1)); @@ -98,7 +98,7 @@ public void testOpenAiCompletions() throws IOException { var taskSettings = (Map) configs.get(0).get("task_settings"); assertThat(taskSettings.keySet(), empty()); - assertCompletionInference(oldClusterId); + assertCompletionInference(inferenceId); } void assertCompletionInference(String inferenceId) throws IOException { diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java index 82bba94cc2607..4e8e1c845b070 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java @@ -120,7 +120,6 @@ public void testOpenAiCompletions() throws IOException { final String upgradedClusterId = "upgraded-cluster-completions"; if (isOldCluster()) { - // TODO why is put only in old cluster? put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("models");