Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@
package org.elasticsearch.xpack.inference.external.elastic;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.Locale;
import java.util.concurrent.Flow;

import static org.elasticsearch.core.Strings.format;

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, true);
Expand All @@ -29,7 +33,8 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
// EIS uses the unified API spec
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
Expand All @@ -52,4 +57,30 @@ protected Exception buildError(String message, Request request, HttpResult resul
return super.buildError(message, request, result, errorResponse);
}
}

/**
 * Converts an error payload received mid-stream (or a mid-stream parse failure) into a
 * {@link UnifiedChatCompletionException} so the caller sees a structured "stream_error".
 *
 * @param request the originating inference request; used for the entity id in the message
 * @param message the raw payload of the server-sent event that reported the error
 * @param e       optional parse failure encountered while reading the stream; may be null
 * @return the exception to propagate downstream
 */
private static Exception buildMidStreamError(Request request, String message, Exception e) {
    var parsedError = ElasticInferenceServiceErrorResponseEntity.fromString(message);
    if (parsedError.errorStructureFound()) {
        // The payload matched the EIS error shape; surface its message verbatim.
        var detail = format(
            "%s for request from inference entity id [%s]. Error message: [%s]",
            SERVER_ERROR_OBJECT,
            request.getInferenceEntityId(),
            parsedError.getErrorMessage()
        );
        return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, detail, "error", "stream_error");
    }
    if (e != null) {
        // Not an EIS error body, but we do have a concrete parsing exception to wrap.
        return UnifiedChatCompletionException.fromThrowable(e);
    }
    // Neither a recognizable error body nor an exception: report a generic stream error.
    var detail = format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId());
    return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, detail, "error", "stream_error");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -29,6 +30,8 @@
import java.util.Optional;
import java.util.concurrent.Flow;

import static org.elasticsearch.core.Strings.format;

public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
Expand All @@ -37,7 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor();
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
Expand All @@ -64,6 +67,33 @@ protected Exception buildError(String message, Request request, HttpResult resul
}
}

/**
 * Converts an error payload received mid-stream (or a mid-stream parse failure) into a
 * {@link UnifiedChatCompletionException}.
 *
 * @param request the originating inference request; used for the entity id in the message
 * @param message the raw payload of the server-sent event that reported the error
 * @param e       optional parse failure encountered while reading the stream; may be null
 * @return the exception to propagate downstream
 */
private static Exception buildMidStreamError(Request request, String message, Exception e) {
    var errorResponse = OpenAiErrorResponse.fromString(message);
    if (errorResponse instanceof OpenAiErrorResponse oer) {
        // The payload parsed as an OpenAI error object; carry its type/code/param through.
        return new UnifiedChatCompletionException(
            RestStatus.INTERNAL_SERVER_ERROR,
            format(
                "%s for request from inference entity id [%s]. Error message: [%s]",
                SERVER_ERROR_OBJECT,
                request.getInferenceEntityId(),
                oer.getErrorMessage()
            ),
            oer.type(),
            oer.code(),
            oer.param()
        );
    } else if (e != null) {
        // Not an OpenAI error body, but we have a concrete exception to wrap.
        return UnifiedChatCompletionException.fromThrowable(e);
    } else {
        // fromString never returns null (it falls back to ErrorResponse.UNDEFINED_ERROR),
        // so the previous `errorResponse != null ? ... : "unknown"` guard was dead code.
        return new UnifiedChatCompletionException(
            RestStatus.INTERNAL_SERVER_ERROR,
            format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
            errorResponse.getClass().getSimpleName(),
            "stream_error"
        );
    }
}

private static class OpenAiErrorResponse extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"open_ai_error",
Expand Down Expand Up @@ -103,6 +133,19 @@ private static ErrorResponse fromResponse(HttpResult response) {
return ErrorResponse.UNDEFINED_ERROR;
}

/**
 * Best-effort parse of an OpenAI-style error payload from a raw string.
 *
 * @param response the raw JSON text of a (suspected) error payload
 * @return the parsed error, or {@link ErrorResponse#UNDEFINED_ERROR} when the
 *         text is not parseable or does not match the expected structure
 */
private static ErrorResponse fromString(String response) {
    var json = XContentFactory.xContent(XContentType.JSON);
    try (XContentParser jsonParser = json.createParser(XContentParserConfiguration.EMPTY, response)) {
        return ERROR_PARSER.apply(jsonParser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
    } catch (Exception ignored) {
        // Best effort: any parse failure maps to the undefined-error sentinel.
        return ErrorResponse.UNDEFINED_ERROR;
    }
}

@Nullable
private final String code;
@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -28,6 +29,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.BiFunction;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
Expand Down Expand Up @@ -57,7 +59,13 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TOTAL_TOKENS_FIELD = "total_tokens";

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
private volatile boolean previousEventWasError = false;

/**
 * @param errorParser converts a raw server-sent event payload (and an optional parse
 *                    failure, which may be null) into the exception thrown to the
 *                    subscriber when the stream reports or triggers an error
 */
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
this.errorParser = errorParser;
}

@Override
protected void upstreamRequest(long n) {
Expand All @@ -71,7 +79,25 @@ protected void upstreamRequest(long n) {
@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger);

var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
for (var event : item) {
if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
previousEventWasError = true;
} else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (previousEventWasError) {
throw errorParser.apply(event.value(), null);
}

try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.value(), e);
}
}
}

if (results.isEmpty()) {
upstream().request(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,26 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

import java.io.IOException;

/**
* An example error response would look like
*
* <code>
* {
* "error": "some error"
* }
* </code>
*
*/
public class ElasticInferenceServiceErrorResponseEntity extends ErrorResponse {

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceErrorResponseEntity.class);
Expand All @@ -24,24 +37,18 @@ private ElasticInferenceServiceErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

/**
* An example error response would look like
*
* <code>
* {
* "error": "some error"
* }
* </code>
*
* @param response The error response
* @return An error entity if the response is JSON with the above structure
* or {@link ErrorResponse#UNDEFINED_ERROR} if the error field wasn't found
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return fromParser(
() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
);
}

public static ErrorResponse fromString(String response) {
return fromParser(() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response));
}

private static ErrorResponse fromParser(CheckedSupplier<XContentParser, IOException> jsonParserFactory) {
try (XContentParser jsonParser = jsonParserFactory.get()) {
var responseMap = jsonParser.map();
var error = (String) responseMap.get("error");
if (error != null) {
Expand All @@ -50,7 +57,6 @@ public static ErrorResponse fromResponse(HttpResult response) {
} catch (Exception e) {
logger.debug("Failed to parse error response", e);
}

return ErrorResponse.UNDEFINED_ERROR;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,51 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
}

// Verifies that a non-200 initial response from the gateway is surfaced as a
// structured unified-chat-completion error (code "not_found", type "error").
public void testUnifiedCompletionError() throws Exception {
testUnifiedStreamError(404, """
{
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
}""", """
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
"type":"error"\
}}""");
}

// Verifies that an SSE data frame carrying an error object mid-stream (HTTP 200)
// is converted into a "stream_error" unified-chat-completion error.
public void testUnifiedCompletionErrorMidStream() throws Exception {
testUnifiedStreamError(200, """
data: { "error": "some error" }

""", """
{\
"error":{\
"code":"stream_error",\
"message":"Received an error response for request from inference entity id [id]. Error message: [some error]",\
"type":"error"\
}}""");
}

// Verifies that a malformed (non-JSON) SSE data frame mid-stream is surfaced as an
// x_content_parse_exception with the parser's position details in the message.
public void testUnifiedCompletionMalformedError() throws Exception {
testUnifiedStreamError(200, """
data: { i am not json }

""", """
{\
"error":{\
"code":"bad_request",\
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
at [Source: (String)\\"{ i am not json }\\"; line: 1, column: 3]",\
"type":"x_content_parse_exception"\
}}""");
}

private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
var eisGatewayUrl = getUrl(webServer);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createService(senderFactory, eisGatewayUrl)) {
var responseJson = """
{
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
}""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
var model = new ElasticInferenceServiceCompletionModel(
"id",
TaskType.COMPLETION,
Expand Down Expand Up @@ -1012,14 +1049,7 @@ public void testUnifiedCompletionError() throws Exception {
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());

assertThat(json, is("""
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
"type":"error"\
}}"""));
assertThat(json, is(expectedJson));
}
});
}
Expand Down
Loading