Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,22 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
private final Supplier<List<ChunkInferenceInput>> listSupplier;
private final InputType inputType;

/**
 * Creates a non-streaming embeddings input from pre-chunked inputs.
 * Delegates to the three-argument constructor with {@code stream} set to {@code false}.
 *
 * @param input     the chunked inference inputs
 * @param inputType the inference input type; may be {@code null}
 */
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
this(input, inputType, false);
}

/**
 * Creates an embeddings input whose chunked inputs are produced lazily by the
 * given supplier rather than materialized up front.
 *
 * @param inputSupplier supplies the chunked inference inputs; must not be {@code null}
 * @param inputType     the inference input type; may be {@code null}
 * @throws NullPointerException if {@code inputSupplier} is {@code null}
 */
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
// NOTE(review): the boolean passed to super appears to be the stream flag
// (the sibling constructors forward a `stream` parameter) — confirm.
super(false);
this.listSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
}

/**
 * Creates a non-streaming embeddings input from raw strings, pairing each
 * string with the given chunking settings. Delegates to the four-argument
 * constructor with {@code stream} set to {@code false}.
 *
 * @param input            the raw input strings to embed
 * @param chunkingSettings the chunking settings applied to every input string; may be {@code null}
 * @param inputType        the inference input type; may be {@code null}
 */
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
// The diff render contained both the old and the new delegation; a
// constructor may contain exactly one `this(...)` call, as its first
// statement. Keep the current delegation to the four-argument constructor.
this(input, chunkingSettings, inputType, false);
}

/**
 * Creates an embeddings input from raw strings: each string is wrapped in a
 * {@link ChunkInferenceInput} carrying the given chunking settings, then
 * delegated to the chunked-input constructor.
 *
 * @param input            the raw input strings to embed
 * @param chunkingSettings the chunking settings attached to every input string; may be {@code null}
 * @param inputType        the inference input type; may be {@code null}
 * @param stream           whether this is a streaming request
 */
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
}

/**
 * Creates a non-streaming embeddings input from pre-chunked inputs.
 * Delegates to the three-argument constructor with {@code stream} set to {@code false}.
 *
 * @param input     the chunked inference inputs
 * @param inputType the inference input type; may be {@code null}
 */
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
this(input, inputType, false);
}

public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,39 +71,37 @@ public void infer(
ActionListener<InferenceServiceResults> listener
) {
init();
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, timeout, listener);
}

private static InferenceInputs createInput(
SenderService service,
Model model,
List<ChunkInferenceInput> input,
List<String> input,
InputType inputType,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
boolean stream
) {
List<String> textInput = ChunkInferenceInput.inputs(input);
return switch (model.getTaskType()) {
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream);
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
case RERANK -> {
ValidationException validationException = new ValidationException();
service.validateRerankParameters(returnDocuments, topN, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream);
yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
}
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
ValidationException validationException = new ValidationException();
service.validateInputType(inputType, model, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
yield new EmbeddingsInput(input, inputType, stream);
yield new EmbeddingsInput(input, null, inputType, stream);
}
default -> throw new ElasticsearchStatusException(
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
Expand Down