Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route
var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route);
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public final class Paths {
+ INFERENCE_ID
+ "}/_stream";

// TODO remove the _unified path
public static final String UNIFIED_SUFFIX = "_unified";
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -50,4 +55,32 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
return new ServerSentEventsRestActionListener(channel, threadPool);
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
var params = parseParams(restRequest);
var inferTimeout = parseTimeout(restRequest);

if (params.taskType() == TaskType.CHAT_COMPLETION) {
UnifiedCompletionAction.Request request;
try (var parser = restRequest.contentParser()) {
request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser);
}

return channel -> client.execute(
UnifiedCompletionAction.INSTANCE,
request,
new ServerSentEventsRestActionListener(channel, threadPool)
);
} else {
InferenceAction.Request.Builder requestBuilder;
try (var parser = restRequest.contentParser()) {
requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser);
}

requestBuilder.setInferenceTimeout(inferTimeout);
var request = prepareInferenceRequest(requestBuilder);
return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel));
}
}
}