Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Apply windowing and chunking to long documents #104363

Merged
merged 24 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ static TransportVersion def(int id) {
public static final TransportVersion PEERFINDER_REPORTS_PEERS_MASTERS = def(8_575_00_0);
public static final TransportVersion ESQL_MULTI_CLUSTERS_ENRICH = def(8_576_00_0);
public static final TransportVersion NESTED_KNN_MORE_INNER_HITS = def(8_577_00_0);
public static final TransportVersion REQUIRE_DATA_STREAM_ADDED = def(8_578_00_0);
public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_579_00_0);
public static final TransportVersion DESIRED_NODE_VERSION_OPTIONAL_STRING = def(8_580_00_0);
public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED = def(8_581_00_0);
// NLP_DOCUMENT_CHUNKING_ADDED is defined exactly once, at the end of the list, so that it
// does not share id 8_578_00_0 with REQUIRE_DATA_STREAM_ADDED.
public static final TransportVersion NLP_DOCUMENT_CHUNKING_ADDED = def(8_582_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

/**
 * Result type handed to the listener of a {@code chunkedInfer} call.
 * <p>
 * The interface declares no members of its own; everything it exposes is
 * inherited from {@link InferenceServiceResults}.
 */
public interface ChunkedInferenceServiceResults extends InferenceServiceResults {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

import org.elasticsearch.core.Nullable;

/**
 * Window and span options for chunking. Both components are optional and may
 * be {@code null}.
 *
 * @param windowSize the window size, or {@code null} if not set
 * @param span       the span, or {@code null} if not set
 */
public record ChunkingOptions(@Nullable Integer windowSize, @Nullable Integer span) {

    /**
     * Reports whether any option has been supplied.
     *
     * @return {@code true} if {@code windowSize} or {@code span} (or both) is
     *         non-null, {@code false} when both are {@code null}
     */
    public boolean settingsArePresent() {
        final boolean nothingConfigured = windowSize == null && span == null;
        return nothingConfigured == false;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ default void init(Client client) {}
* @param model The model
* @param input Inference input
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param listener Inference result listener
*/
void infer(
Expand All @@ -86,6 +87,27 @@ void infer(
ActionListener<InferenceServiceResults> listener
);

/**
 * Chunk long text according to {@code chunkingOptions}, falling back to
 * the model defaults for any option that {@code chunkingOptions} leaves
 * unset. Nothing is returned directly; the outcome is handed to
 * {@code listener}.
 *
 * @param model The model
 * @param input Inference input
 * @param taskSettings Settings in the request to override the model's defaults
 * @param inputType For search, ingest etc
 * @param chunkingOptions The window and span options to apply
 * @param listener Inference result listener, which receives the chunked results
 */
void chunkedInfer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ChunkingOptions chunkingOptions,
ActionListener<ChunkedInferenceServiceResults> listener
);

/**
* Start or prepare the model for use.
* @param model The model
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Chunked inference service results backed by a list of
 * {@link ChunkedTextExpansionResults.ChunkedResult} entries.
 * <p>
 * Instances can be written to and read from a stream and rendered as XContent.
 * The coordination, legacy and map representations are not supported and the
 * corresponding methods always throw {@link UnsupportedOperationException}.
 */
public class ChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults {

    /** Name returned by {@link #getWriteableName()}. */
    public static final String NAME = "chunked_sparse_embedding_results";

    /**
     * Creates a result from the chunks of an ML text expansion result.
     *
     * @param mlInferenceResults the ML result whose chunks are wrapped
     * @return a new instance holding {@code mlInferenceResults.getChunks()}
     */
    public static ChunkedSparseEmbeddingResults ofMlResult(ChunkedTextExpansionResults mlInferenceResults) {
        return new ChunkedSparseEmbeddingResults(mlInferenceResults.getChunks());
    }

    private final List<ChunkedTextExpansionResults.ChunkedResult> chunkedResults;

    /**
     * @param chunks the chunk results to hold; the list is stored as given, without copying
     */
    public ChunkedSparseEmbeddingResults(List<ChunkedTextExpansionResults.ChunkedResult> chunks) {
        this.chunkedResults = chunks;
    }

    /**
     * Reads an instance from a stream, mirroring {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read the chunk collection from
     * @throws IOException if reading from the stream fails
     */
    public ChunkedSparseEmbeddingResults(StreamInput in) throws IOException {
        this.chunkedResults = in.readCollectionAsList(ChunkedTextExpansionResults.ChunkedResult::new);
    }

    /**
     * Writes the chunks as an array field named {@code sparse_embedding_chunk};
     * each element is rendered by the chunk's own {@code toXContent}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startArray("sparse_embedding_chunk");
        for (ChunkedTextExpansionResults.ChunkedResult chunk : chunkedResults) {
            chunk.toXContent(builder, params);
        }
        builder.endArray();
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Writes the chunk collection to the stream. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(chunkedResults);
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<String, Object> asMap() {
        throw new UnsupportedOperationException("Chunked results are not returned in a map format");
    }

    /** Two instances are equal when they are of the same class and hold equal chunk lists. */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ChunkedSparseEmbeddingResults that = (ChunkedSparseEmbeddingResults) o;
        return Objects.equals(chunkedResults, that.chunkedResults);
    }

    @Override
    public int hashCode() {
        return Objects.hash(chunkedResults);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Chunked inference service results backed by a list of ML
 * {@code ChunkedTextEmbeddingResults.EmbeddingChunk} entries.
 * <p>
 * The ML result class has the same simple name as this class, so it is always
 * referred to by its fully qualified name here.
 * <p>
 * Instances can be written to and read from a stream and rendered as XContent.
 * The coordination, legacy and map representations are not supported and the
 * corresponding methods always throw {@link UnsupportedOperationException}.
 */
public class ChunkedTextEmbeddingResults implements ChunkedInferenceServiceResults {

    /** Name returned by {@link #getWriteableName()}. */
    public static final String NAME = "chunked_text_embedding_service_results";

    /**
     * Creates a result from the chunks of an ML chunked text embedding result.
     *
     * @param mlInferenceResults the ML result whose chunks are wrapped
     * @return a new instance holding {@code mlInferenceResults.getChunks()}
     */
    public static ChunkedTextEmbeddingResults ofMlResult(
        org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults mlInferenceResults
    ) {
        return new ChunkedTextEmbeddingResults(mlInferenceResults.getChunks());
    }

    private final List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks;

    /**
     * @param chunks the embedding chunks to hold; the list is stored as given, without copying
     */
    public ChunkedTextEmbeddingResults(
        List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks
    ) {
        this.chunks = chunks;
    }

    /**
     * Reads an instance from a stream, mirroring {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read the chunk collection from
     * @throws IOException if reading from the stream fails
     */
    public ChunkedTextEmbeddingResults(StreamInput in) throws IOException {
        this.chunks = in.readCollectionAsList(
            org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk::new
        );
    }

    /**
     * Writes the chunks as an array field named {@code text_embedding_chunk};
     * each element is rendered by the chunk's own {@code toXContent}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startArray("text_embedding_chunk");
        for (var embedding : chunks) {
            embedding.toXContent(builder, params);
        }
        builder.endArray();
        return builder;
    }

    /** Writes the chunk collection to the stream. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(chunks);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<String, Object> asMap() {
        throw new UnsupportedOperationException("Chunked results are not returned in a map format");
    }

    /** Two instances are equal when they are of the same class and hold equal chunk lists. */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ChunkedTextEmbeddingResults that = (ChunkedTextEmbeddingResults) o;
        return Objects.equals(chunks, that.chunks);
    }

    @Override
    public int hashCode() {
        return Objects.hash(chunks);
    }
}