diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index cffb12e0713da..1b01f37955e5a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -299,17 +299,13 @@ public void testUnsupportedStream() throws Exception { try { var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null); - assertThat(events.size(), equalTo(2)); + assertThat(events.size(), equalTo(1)); events.forEach(event -> { - switch (event.name()) { - case EVENT -> assertThat(event.value(), equalToIgnoringCase("error")); - case DATA -> assertThat( - event.value(), - containsString( - "Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]" - ) - ); - } + assertThat(event.type(), equalToIgnoringCase("error")); + assertThat( + event.data(), + containsString("Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]") + ); }); } finally { deleteModel(modelId); @@ -331,12 +327,10 @@ public void testSupportedStream() throws Exception { input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), Stream.of("[DONE]") ).iterator(); - assertThat(events.size(), equalTo((input.size() + 1) * 2)); + assertThat(events.size(), equalTo(input.size() + 1)); events.forEach(event -> { - switch (event.name()) { - case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); - case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); - } + assertThat(event.type(), equalToIgnoringCase("message")); + assertThat(event.data(), equalTo(expectedResponses.next())); }); } finally { deleteModel(modelId); @@ -359,12 +353,10 @@ public void testUnifiedCompletionInference() throws Exception { VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER ); var expectedResponses = expectedResultsIterator(input); - assertThat(events.size(), equalTo((input.size() + 1) * 2)); + assertThat(events.size(), equalTo(input.size() + 1)); events.forEach(event -> { - switch (event.name()) { - case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); - case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); - } + assertThat(event.type(), equalToIgnoringCase("message")); + assertThat(event.data(), equalTo(expectedResponses.next())); }); } finally { deleteModel(modelId); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index f837ff5c4049d..b480b2327323d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -51,7 +51,6 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.XContentFormattedException; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import java.io.IOException; @@ -364,9 +363,8 @@ private static class RandomStringCollector { private void collect(String str) throws IOException { sseParser.parse(str.getBytes(StandardCharsets.UTF_8)) .stream() - .filter(event -> event.name() == ServerSentEventField.DATA) - .filter(ServerSentEvent::hasValue) - .map(ServerSentEvent::value) + .filter(ServerSentEvent::hasData) + .map(ServerSentEvent::data) .forEach(stringsVerified::offer); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index eda3fc0f3bfdb..6160f51709299 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -40,7 +39,7 @@ public static Deque parseEvent( ) throws Exception { var results = new ArrayDeque(item.size()); for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + if (event.hasData()) { try { var delta = parseFunction.apply(parserConfig, event); delta.forEachRemaining(results::offer); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java index 8625f5d6123c5..d1af805a00388 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -42,8 +41,8 @@ protected void next(Deque item) throws Exception { var results = new ArrayDeque(item.size()); for (var event : item) { - if (event.name() == ServerSentEventField.DATA && event.hasValue()) { - try (var parser = parser(event.value())) { + if (event.hasData()) { + try (var parser = parser(event.data())) { var eventType = eventType(parser); switch (eventType) { case "error" -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioStreamingProcessor.java index aa1232f4182e3..3ff6b10db5d60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioStreamingProcessor.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -37,8 +36,8 @@ protected void next(Deque item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); var results = new ArrayDeque(item.size()); for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + if (event.hasData()) { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) { var delta = content.apply(jsonParser); results.offer(new StreamingChatCompletionResults.Result(delta)); } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index fcfd8e19004c5..d708f8bbe948a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -124,11 +124,11 @@ protected void next(Deque item) throws Exception { private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { - if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { + if (DONE_MESSAGE.equalsIgnoreCase(event.data())) { return Collections.emptyIterator(); } - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) { moveToFirstToken(jsonParser); XContentParser.Token token = jsonParser.currentToken(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 59c17e890e9f5..97e7d312f9577 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -62,7 +61,6 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor< private final BiFunction errorParser; private final Deque buffer = new LinkedBlockingDeque<>(); - private volatile boolean previousEventWasError = false; public OpenAiUnifiedStreamingProcessor(BiFunction errorParser) { this.errorParser = errorParser; @@ -83,19 +81,15 @@ protected void next(Deque item) throws Exception { var results = new ArrayDeque(item.size()); for (var event : item) { - if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) { - previousEventWasError = true; - } else if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - if (previousEventWasError) { - throw errorParser.apply(event.value(), null); - } - + if ("error".equals(event.type()) && event.hasData()) { + throw errorParser.apply(event.data(), null); + } else if (event.hasData()) { try { var delta = parse(parserConfig, event); delta.forEachRemaining(results::offer); } catch (Exception e) { logger.warn("Failed to parse event from inference provider: {}", event); - throw errorParser.apply(event.value(), e); + throw errorParser.apply(event.data(), e); } } } @@ -118,11 +112,11 @@ private static Iterator possibleValues = Arrays.stream(values()) - .map(Enum::name) - .map(name -> name.toLowerCase(Locale.ROOT)) - .collect(Collectors.toSet()); - - static Optional oneOf(String name) { - if (name != null && possibleValues.contains(name.toLowerCase(Locale.ROOT))) { - return Optional.of(valueOf(name.toUpperCase(Locale.ROOT))); - } else { - return Optional.empty(); - } - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java index 9856a116f17d2..8af44c19a176a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java @@ -10,8 +10,8 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayDeque; import java.util.Deque; +import java.util.Locale; import java.util.Optional; -import java.util.regex.Pattern; /** * https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation @@ -20,11 +20,14 @@ * If the line starts with a colon, we discard this event. * If the line contains a colon, we process it into {@link ServerSentEvent} with a non-empty value. * If the line does not contain a colon, we process it into {@link ServerSentEvent}with an empty string value. - * If the line's field is not one of {@link ServerSentEventField}, we discard this event. + * If the line's field is not one of (data, event), we discard this event. `id` and `retry` are not implemented, because we do not use them + * and have no plans to use them. */ public class ServerSentEventParser { - private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r|\\r\\n"); private static final String BOM = "\uFEFF"; + private static final String TYPE_FIELD = "event"; + private static final String DATA_FIELD = "data"; + private final EventBuffer eventBuffer = new EventBuffer(); private volatile String previousTokens = ""; public Deque parse(byte[] bytes) { @@ -33,43 +36,101 @@ public Deque parse(byte[] bytes) { } var body = previousTokens + new String(bytes, StandardCharsets.UTF_8); - var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings + var lines = body.lines(); - var collector = new ArrayDeque(lines.length); - for (var i = 0; i < lines.length - 1; i++) { - var line = lines[i].replace(BOM, ""); + var collector = new ArrayDeque(); + lines.reduce((previousLine, nextLine) -> { + var line = previousLine.replace(BOM, ""); - if (line.isBlank() == false && line.startsWith(":") == false) { + if (line.isEmpty()) { + eventBuffer.dispatch().ifPresent(collector::offer); + } else if (line.startsWith(":") == false) { if (line.contains(":")) { - fieldValueEvent(line).ifPresent(collector::offer); - } else { - ServerSentEventField.oneOf(line).map(ServerSentEvent::new).ifPresent(collector::offer); + fieldValueEvent(line); + } else if (DATA_FIELD.equals(line.toLowerCase(Locale.ROOT))) { + eventBuffer.data(""); } } - } - - // we can sometimes get bytes for incomplete messages, so we save them for the next onNext invocation - // if we get an onComplete before we clear this cache, we follow the spec to treat it as an incomplete event and discard it since - // it was not followed by a blank line - previousTokens = lines[lines.length - 1]; + return nextLine; + }).ifPresent(lastLine -> { + if (lastLine.isEmpty()) { + // if the last line is an empty line, then we dispatch the event and clear the previousToken cache + eventBuffer.dispatch().ifPresent(collector::offer); + previousTokens = ""; + } else { + // we can sometimes get bytes for incomplete messages, so we save them for the next onNext invocation + // if we get an onComplete before we clear this cache, we follow the spec to treat it as an incomplete event and discard it + // since it was not followed by a blank line + previousTokens = lastLine; + } + }); return collector; } - private Optional fieldValueEvent(String lineWithColon) { + private void fieldValueEvent(String lineWithColon) { var firstColon = lineWithColon.indexOf(":"); - var fieldStr = lineWithColon.substring(0, firstColon); - var serverSentField = ServerSentEventField.oneOf(fieldStr); - - if ((firstColon + 1) != lineWithColon.length()) { - var value = lineWithColon.substring(firstColon + 1); - if (value.equals(" ") == false) { - var trimmedValue = value.charAt(0) == ' ' ? value.substring(1) : value; - return serverSentField.map(field -> new ServerSentEvent(field, trimmedValue)); + var fieldStr = lineWithColon.substring(0, firstColon).toLowerCase(Locale.ROOT); + + var value = lineWithColon.substring(firstColon + 1); + // "If value starts with a U+0020 SPACE character, remove it from value." + var trimmedValue = value.length() > 0 && value.charAt(0) == ' ' ? value.substring(1) : value; + + if (DATA_FIELD.equals(fieldStr)) { + eventBuffer.data(trimmedValue); + } else if (TYPE_FIELD.equals(fieldStr)) { + eventBuffer.type(trimmedValue); + } + } + + private static class EventBuffer { + private static final char LINE_FEED = '\n'; + private static final String MESSAGE = "message"; + private StringBuilder type = new StringBuilder(); + private StringBuilder data = new StringBuilder(); + private boolean appendLineFeed = false; + + private void type(String type) { + this.type.append(type); + } + + private void data(String data) { + // "Append the field value to the data buffer, then append a single U+000A LINE FEED (LF) character to the data buffer." + // But then we're told "If the data buffer's last character is a U+000A LINE FEED (LF) character, + // then remove the last character from the data buffer." + // Rather than add + remove for single-line data fields, we only append a LINE FEED on subsequent data lines. + if (appendLineFeed) { + this.data.append(LINE_FEED); + } else { + appendLineFeed = true; } + this.data.append(data); } - // if we have "data:" or "data: ", treat it like a no-value line - return serverSentField.map(ServerSentEvent::new); + private Optional dispatch() { + // "If the data buffer is an empty string, set the data buffer and the event type buffer to the empty string and return." + // We don't process empty events anywhere, so we just drop the events. + if (data.isEmpty()) { + reset(); + return Optional.empty(); + } + var dataValue = data.toString(); + + // "Initialize event's type attribute to "message"" + // "If the event type buffer has a value other than the empty string, + // change the type of the newly created event to equal the value of the event type buffer." + var typeValue = type.toString(); + typeValue = typeValue.isBlank() ? MESSAGE : typeValue; + + reset(); + + return Optional.of(new ServerSentEvent(typeValue, dataValue)); + } + + private void reset() { + type = new StringBuilder(); + data = new StringBuilder(); + appendLineFeed = false; + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java index 90d0e8742f733..4183649e3f52a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -33,7 +32,7 @@ public class OpenAiStreamingProcessorTests extends ESTestCase { public void testParseOpenAiResponse() throws IOException { var item = new ArrayDeque(); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + item.offer(new ServerSentEvent(""" { "id":"12345", "object":"chat.completion.chunk", @@ -62,7 +61,7 @@ public void testParseOpenAiResponse() throws IOException { public void testParseWithFinish() throws IOException { var item = new ArrayDeque(); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + item.offer(new ServerSentEvent(""" { "id":"12345", "object":"chat.completion.chunk", @@ -81,7 +80,7 @@ public void testParseWithFinish() throws IOException { ] } """)); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + item.offer(new ServerSentEvent(""" { "id":"12345", "object":"chat.completion.chunk", @@ -108,7 +107,7 @@ public void testParseWithFinish() throws IOException { public void testParseErrorCallsOnError() { var item = new ArrayDeque(); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, "this isn't json")); + item.offer(new ServerSentEvent("this isn't json")); var exception = onError(new OpenAiStreamingProcessor(), item); assertThat(exception, instanceOf(XContentParseException.class)); @@ -133,7 +132,7 @@ public void testEmptyResultsRequestsMoreData() throws Exception { public void testDoneMessageIsIgnored() throws Exception { var item = new ArrayDeque(); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, "[DONE]")); + item.offer(new ServerSentEvent("[DONE]")); var processor = new OpenAiStreamingProcessor(); @@ -151,7 +150,7 @@ public void testDoneMessageIsIgnored() throws Exception { public void testInitialLlamaResponseIsIgnored() throws Exception { var item = new ArrayDeque(); - item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + item.offer(new ServerSentEvent(""" { "id":"12345", "object":"chat.completion.chunk", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java index 863d2e3e07c5f..d7cff63d8c708 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java @@ -10,26 +10,49 @@ import org.elasticsearch.test.ESTestCase; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Deque; import java.util.List; -import java.util.Locale; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; public class ServerSentEventParserTests extends ESTestCase { - public void testParseEvents() { - var payload = (Arrays.stream(ServerSentEventField.values()) - .map(ServerSentEventField::name) - .map(name -> name.toLowerCase(Locale.ROOT)) - .collect(Collectors.joining("\n")) - + "\n").getBytes(StandardCharsets.UTF_8); + public void testRetryAndIdAreUnimplemented() { + var payload = """ + id: 2 + + retry: 2 + + """.getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertTrue(events.isEmpty()); + } + + public void testEmptyEventDefaultsToMessage() { + var payload = """ + data: hello + + """.getBytes(StandardCharsets.UTF_8); var parser = new ServerSentEventParser(); var events = parser.parse(payload); - assertThat(events.size(), equalTo(ServerSentEventField.values().length)); + assertEvents(events, List.of(new ServerSentEvent("message", "hello"))); + } + + public void testEmptyData() { + var payload = """ + data + data + + """.getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertEvents(events, List.of(new ServerSentEvent("message", "\n"))); } public void testParseDataEventsWithAllEndOfLines() { @@ -51,16 +74,27 @@ public void testParseDataEventsWithAllEndOfLines() { assertEvents( events, List.of( - new ServerSentEvent(ServerSentEventField.EVENT, "message"), - new ServerSentEvent(ServerSentEventField.DATA, "test"), - new ServerSentEvent(ServerSentEventField.EVENT, "message"), - new ServerSentEvent(ServerSentEventField.DATA, "test2"), - new ServerSentEvent(ServerSentEventField.EVENT, "message"), - new ServerSentEvent(ServerSentEventField.DATA, "test3") + new ServerSentEvent("message", "test"), + new ServerSentEvent("message", "test2"), + new ServerSentEvent("message", "test3") ) ); } + public void testParseMultiLineDataEvents() { + var payload = """ + event: message + data: hello + data: there + + """.getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertEvents(events, List.of(new ServerSentEvent("message", "hello\nthere"))); + } + private void assertEvents(Deque actualEvents, List expectedEvents) { assertThat(actualEvents.size(), equalTo(expectedEvents.size())); var expectedEvent = expectedEvents.iterator(); @@ -69,13 +103,13 @@ private void assertEvents(Deque actualEvents, Listevent: message\n\n" - var payload = new byte[] { -17, -69, -65, 101, 118, 101, 110, 116, 58, 32, 109, 101, 115, 115, 97, 103, 101, 10, 10 }; + // these are the bytes for "data: hello\n\n" + var payload = new byte[] { -17, -69, -65, 100, 97, 116, 97, 58, 32, 104, 101, 108, 108, 111, 10, 10 }; var parser = new ServerSentEventParser(); var events = parser.parse(payload); - assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.EVENT, "message"))); + assertEvents(events, List.of(new ServerSentEvent("message", "hello"))); } public void testEmptyEventIsSetAsEmptyString() { @@ -88,10 +122,7 @@ public void testEmptyEventIsSetAsEmptyString() { var parser = new ServerSentEventParser(); var events = parser.parse(payload); - assertEvents( - events, - List.of(new ServerSentEvent(ServerSentEventField.EVENT, ""), new ServerSentEvent(ServerSentEventField.EVENT, "")) - ); + assertTrue(events.isEmpty()); } public void testCommentsAreIgnored() { @@ -103,7 +134,7 @@ public void testCommentsAreIgnored() { """.getBytes(StandardCharsets.UTF_8)); - assertThat(events.isEmpty(), equalTo(true)); + assertTrue(events.isEmpty()); } public void testCarryOverBytes() { @@ -113,13 +144,13 @@ public void testCarryOverBytes() { event: message data""".getBytes(StandardCharsets.UTF_8)); // no newline after 'data' so the parser won't split the message up - assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.EVENT, "message"))); + assertTrue(events.isEmpty()); events = parser.parse(""" :test """.getBytes(StandardCharsets.UTF_8)); - assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.DATA, "test"))); + assertEvents(events, List.of(new ServerSentEvent("test"))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java index 0a0712c69cc3f..30c5aa3e4f7df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java @@ -59,7 +59,7 @@ public void testEmptyParseResponse() { public void testResponse() { ServerSentEventParser parser = mock(); var deque = new ArrayDeque(); - deque.offer(new ServerSentEvent(ServerSentEventField.EVENT, "hello")); + deque.offer(new ServerSentEvent("hello")); when(parser.parse(any())).thenReturn(deque); var processor = new ServerSentEventProcessor(parser); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java index e0aef58c4f3b3..ba00810d47f4d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java @@ -19,7 +19,7 @@ public class StreamingInferenceTestUtils { public static Deque events(String... data) { var item = new ArrayDeque(); - Arrays.stream(data).map(datum -> new ServerSentEvent(ServerSentEventField.DATA, datum)).forEach(item::offer); + Arrays.stream(data).map(ServerSentEvent::new).forEach(item::offer); return item; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 71e35aa211d3c..f6f7915e74284 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -546,13 +546,28 @@ public void testInfer_SendsCompletionRequest() throws IOException { public void testInfer_StreamRequest() throws Exception { String responseJson = """ + event: message_start data: {"type": "message_start", "message": {"model": "claude, probably"}} + + event: content_block_start data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} + + event: ping data: {"type": "ping"} + + event: content_block_delta data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + + event: content_block_delta data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": ", World"}} + + event: content_block_stop data: {"type": "content_block_stop", "index": 0} + + event: message_delta data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 4}} + + event: message_stop data: {"type": "message_stop"} """;