From 5a824c5a38e51590064e410a292fa50cb0bb0d86 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Fri, 10 May 2024 17:39:04 +0200 Subject: [PATCH] [Inference API] Improve completion response entity tests (#108512) --- .../AzureOpenAiCompletionResponseEntityTests.java | 9 +++------ .../openai/OpenAiChatCompletionResponseEntityTests.java | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java index 3afe4bd439e0f..ec76f43a6d52f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java @@ -17,7 +17,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -50,7 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { "index": 0, "logprobs": null, "message": { - "content": "response", + "content": "result", "role": "assistant" } } @@ -92,10 +91,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(chatCompletionResults.getResults().size(), equalTo(1)); - - ChatCompletionResults.Result result = chatCompletionResults.getResults().get(0); - assertThat(result.asMap().get(result.getResultsField()), is("response")); + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); } public void testFromResponse_FailsWhenChoicesFieldIsNotPresent() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java index 080602e8fd245..5604d6573144e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java @@ -17,7 +17,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -35,7 +34,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { "index": 0, "message": { "role": "assistant", - "content": "some content" + "content": "result" }, "logprobs": null, "finish_reason": "stop" @@ -55,7 +54,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(chatCompletionResults.getResults().size(), equalTo(1)); + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); } public void testFromResponse_FailsWhenChoicesFieldIsNotPresent() {