Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -299,17 +299,13 @@ public void testUnsupportedStream() throws Exception {

try {
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null);
assertThat(events.size(), equalTo(2));
assertThat(events.size(), equalTo(1));
events.forEach(event -> {
switch (event.name()) {
case EVENT -> assertThat(event.value(), equalToIgnoringCase("error"));
case DATA -> assertThat(
event.value(),
containsString(
"Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]"
)
);
}
assertThat(event.type(), equalToIgnoringCase("error"));
assertThat(
event.data(),
containsString("Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]")
);
});
} finally {
deleteModel(modelId);
Expand All @@ -331,12 +327,10 @@ public void testSupportedStream() throws Exception {
input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
Stream.of("[DONE]")
).iterator();
assertThat(events.size(), equalTo((input.size() + 1) * 2));
assertThat(events.size(), equalTo(input.size() + 1));
events.forEach(event -> {
switch (event.name()) {
case EVENT -> assertThat(event.value(), equalToIgnoringCase("message"));
case DATA -> assertThat(event.value(), equalTo(expectedResponses.next()));
}
assertThat(event.type(), equalToIgnoringCase("message"));
assertThat(event.data(), equalTo(expectedResponses.next()));
});
} finally {
deleteModel(modelId);
Expand All @@ -359,12 +353,10 @@ public void testUnifiedCompletionInference() throws Exception {
VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER
);
var expectedResponses = expectedResultsIterator(input);
assertThat(events.size(), equalTo((input.size() + 1) * 2));
assertThat(events.size(), equalTo(input.size() + 1));
events.forEach(event -> {
switch (event.name()) {
case EVENT -> assertThat(event.value(), equalToIgnoringCase("message"));
case DATA -> assertThat(event.value(), equalTo(expectedResponses.next()));
}
assertThat(event.type(), equalToIgnoringCase("message"));
assertThat(event.data(), equalTo(expectedResponses.next()));
});
} finally {
deleteModel(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;

import java.io.IOException;
Expand Down Expand Up @@ -364,9 +363,8 @@ private static class RandomStringCollector {
private void collect(String str) throws IOException {
sseParser.parse(str.getBytes(StandardCharsets.UTF_8))
.stream()
.filter(event -> event.name() == ServerSentEventField.DATA)
.filter(ServerSentEvent::hasValue)
.map(ServerSentEvent::value)
.filter(ServerSentEvent::hasData)
.map(ServerSentEvent::data)
.forEach(stringsVerified::offer);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -40,7 +39,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
) throws Exception {
var results = new ArrayDeque<ParsedChunk>(item.size());
for (ServerSentEvent event : item) {
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (event.hasData()) {
try {
var delta = parseFunction.apply(parserConfig, event);
delta.forEachRemaining(results::offer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -42,8 +41,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
for (var event : item) {
if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
try (var parser = parser(event.value())) {
if (event.hasData()) {
try (var parser = parser(event.data())) {
var eventType = eventType(parser);
switch (eventType) {
case "error" -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -37,8 +36,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
for (ServerSentEvent event : item) {
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
if (event.hasData()) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
var delta = content.apply(jsonParser);
results.offer(new StreamingChatCompletionResults.Result(delta));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -62,7 +61,6 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
private volatile boolean previousEventWasError = false;

public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
this.errorParser = errorParser;
Expand All @@ -83,19 +81,15 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
for (var event : item) {
if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
previousEventWasError = true;
} else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (previousEventWasError) {
throw errorParser.apply(event.value(), null);
}

if ("error".equals(event.type()) && event.hasData()) {
throw errorParser.apply(event.data(), null);
} else if (event.hasData()) {
try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.value(), e);
throw errorParser.apply(event.data(), e);
}
}
}
Expand All @@ -118,11 +112,11 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun
XContentParserConfiguration parserConfig,
ServerSentEvent event
) throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@

/**
* Server-Sent Event message: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
* Messages always contain a {@link ServerSentEventField} and a non-null payload value.
* When the stream is parsed and there is no value associated with a {@link ServerSentEventField}, an empty-string is set as the value.
*/
public record ServerSentEvent(ServerSentEventField name, String value) {
public record ServerSentEvent(String type, String data) {

private static final String EMPTY = "";
private static final String MESSAGE = "message";

public ServerSentEvent(ServerSentEventField name) {
this(name, EMPTY);
public static ServerSentEvent empty() {
return new ServerSentEvent(EMPTY, EMPTY);
}

// treat null value as an empty string, don't break parsing
public ServerSentEvent(ServerSentEventField name, String value) {
this.name = name;
this.value = value != null ? value : EMPTY;
public ServerSentEvent(String data) {
this(MESSAGE, data);
}

public boolean hasValue() {
return value.isBlank() == false;
public ServerSentEvent {
data = data != null ? data : EMPTY;
type = type != null && type.isBlank() == false ? type : MESSAGE;
}

public boolean hasData() {
return data.isBlank() == false;
}
}

This file was deleted.

Loading