[ML] Apply windowing and chunking to long documents #104363

Merged 24 commits on Feb 1, 2024. (Diff below shows changes from 6 commits.)
5 changes: 5 additions & 0 deletions docs/changelog/104363.yaml
@@ -0,0 +1,5 @@
pr: 104363
summary: Apply windowing and chunking to long documents
area: Machine Learning
type: enhancement
issues: []
@@ -183,6 +183,7 @@ static TransportVersion def(int id) {
public static final TransportVersion HOT_THREADS_AS_BYTES = def(8_571_00_0);
public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED = def(8_572_00_0);
public static final TransportVersion ESQL_ENRICH_POLICY_CCQ_MODE = def(8_573_00_0);
public static final TransportVersion NLP_DOCUMENT_CHUNKING_ADDED = def(8_574_00_0);

/*
* STOP! READ THIS FIRST! No, really,
@@ -0,0 +1,241 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

public class ChunkedInferenceAction extends ActionType<ChunkedInferenceAction.Response> {

public static final ChunkedInferenceAction INSTANCE = new ChunkedInferenceAction();
public static final String NAME = "cluster:internal/xpack/ml/chunkedinference";

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, ChunkedInferenceAction.Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInputs, new ParseField("inputs"));
PARSER.declareInt(Request.Builder::setWindowSize, new ParseField("window_size"));
PARSER.declareInt(Request.Builder::setSpan, new ParseField("span"));
}

public static Request parseRequest(String id, TimeValue timeout, XContentParser parser) {
Request.Builder builder = PARSER.apply(parser, null);
if (id != null) {
builder.setId(id);
}
if (timeout != null) {
builder.setTimeout(timeout);
}
return builder.build();
}
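
// A hedged sketch (not part of this PR) of the body PARSER accepts. The id and
// timeout are illustrative, since both arrive via the REST path and query
// string rather than the body; this also assumes JsonXContent's
// createParser(XContentParserConfiguration, String) overload.
//
// String body = """
//     { "inputs": ["a long document ..."], "window_size": 256, "span": 100 }
//     """;
// try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, body)) {
//     Request request = parseRequest("my-model", TimeValue.timeValueSeconds(10), parser);
// }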

public ChunkedInferenceAction() {
super(NAME, Response::new);
}

public static class Request extends ActionRequest {
public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(10);

private final String modelId;
private final List<String> inputs;
private final Integer windowSize;
private final Integer span;
private final TimeValue timeout;

public Request(String modelId, List<String> inputs, @Nullable Integer windowSize, @Nullable Integer span, TimeValue timeout) {
this.modelId = modelId;
this.inputs = inputs;
this.windowSize = windowSize;
this.span = span;
this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.inputs = in.readStringCollectionAsList();
this.windowSize = in.readOptionalVInt();
this.span = in.readOptionalVInt();
this.timeout = in.readTimeValue();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeStringCollection(inputs);
out.writeOptionalVInt(windowSize);
out.writeOptionalVInt(span);
out.writeTimeValue(timeout);
}

public boolean containsWindowOptions() {
return span != null || windowSize != null;
}

@Override
public ActionRequestValidationException validate() {
if (windowSize == null || span == null) {
// both fields are needed for validation
return null;
}

var failedValidation = new ActionRequestValidationException();

if (span <= 0 || windowSize <= 0) {
failedValidation.addValidationError("window size and overlap must both be greater than 0");
}
if (span >= windowSize) {
failedValidation.addValidationError("span must be less than window size");
}

return failedValidation.validationErrors().isEmpty() ? null : failedValidation;
}
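
// Hedged examples of what validate() accepts, with illustrative values:
// - window_size 256 with span 100 passes, since 0 < 100 < 256:
//   new Request("my-model", List.of("text ..."), 256, 100, Request.DEFAULT_TIMEOUT).validate() == null
// - span equal to window_size fails with "span must be less than window size":
//   new Request("my-model", List.of("text ..."), 100, 100, Request.DEFAULT_TIMEOUT).validate() != null
// - supplying only one of the two returns null (nothing can be cross-checked),
//   although containsWindowOptions() still reports that an option was set.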

public String getModelId() {
return modelId;
}

public Integer getSpan() {
return span;
}

public Integer getWindowSize() {
return windowSize;
}

public List<String> getInputs() {
return inputs;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(span, request.span)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(inputs, request.inputs)
&& Objects.equals(windowSize, request.windowSize)
&& Objects.equals(timeout, request.timeout);
}

@Override
public int hashCode() {
return Objects.hash(modelId, windowSize, span, inputs, timeout);
}

public static class Builder {

private String id;
private List<String> inputs;
private Integer span = null;
private Integer windowSize = null;
private TimeValue timeout = DEFAULT_TIMEOUT;

public Builder setId(String id) {
this.id = id;
return this;
}

public Builder setInputs(List<String> inputs) {
this.inputs = inputs;
return this;
}

public Builder setWindowSize(int windowSize) {
this.windowSize = windowSize;
return this;
}

public Builder setSpan(int span) {
this.span = span;
return this;
}

public Builder setTimeout(TimeValue timeout) {
this.timeout = timeout;
return this;
}

public Request build() {
return new Request(id, inputs, windowSize, span, timeout);
}
}
}
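
// Hedged usage sketch of the Builder; the model id and values are illustrative:
//
// Request request = new Request.Builder()
//     .setId("my-model") // normally taken from the REST path
//     .setInputs(List.of("long document ..."))
//     .setWindowSize(256) // optional
//     .setSpan(100) // optional; must be less than the window size
//     .setTimeout(TimeValue.timeValueSeconds(30))
//     .build();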

public static class Response extends ActionResponse implements ToXContentObject {

private final List<InferenceResults> inferenceResults;

public Response(List<InferenceResults> inferenceResults) {
super();
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
}

public Response(StreamInput in) throws IOException {
super(in);
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableCollectionAsList(InferenceResults.class));
}

public List<InferenceResults> getInferenceResults() {
return inferenceResults;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteableCollection(inferenceResults);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startArray("inference_results");
for (var inference : inferenceResults) {
// inference results implement ToXContentFragment
builder.startObject();
inference.toXContent(builder, params);
builder.endObject();
}
builder.endArray();
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response that = (Response) o;
return Objects.equals(inferenceResults, that.inferenceResults);
}

@Override
public int hashCode() {
return Objects.hash(inferenceResults);
}
}
}
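
From Response.toXContent, each result is wrapped in its own object inside an inference_results array. A hedged sketch of the rendered shape; the fields inside each object are written by the concrete InferenceResults implementation and are purely illustrative here:

{
  "inference_results" : [
    { "predicted_value" : "..." },
    { "predicted_value" : "..." }
  ]
}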
@@ -90,6 +90,7 @@ public static Builder parseRequest(String id, XContentParser parser) {
private final List<String> textInput;
private boolean highPriority;
private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
private boolean chunkResults = false;

/**
* Build a request from a list of documents as maps.
@@ -197,6 +198,11 @@ public Request(StreamInput in) throws IOException {
} else {
prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
}
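// Older nodes never send the chunking flag, so read it only when the remote
// transport version includes NLP_DOCUMENT_CHUNKING_ADDED and fall back to the
// pre-chunking default (false) otherwise. The write side below gates on the
// same version so the two wire formats stay symmetric.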
if (in.getTransportVersion().onOrAfter(TransportVersions.NLP_DOCUMENT_CHUNKING_ADDED)) {
chunkResults = in.readBoolean();
} else {
chunkResults = false;
}
}

public int numberOfDocuments() {
@@ -243,6 +249,14 @@ public void setPrefixType(TrainedModelPrefixStrings.PrefixType prefixType) {
this.prefixType = prefixType;
}

public boolean isChunkResults() {
return chunkResults;
}

public void setChunkResults(boolean chunkResults) {
this.chunkResults = chunkResults;
}

public TrainedModelPrefixStrings.PrefixType getPrefixType() {
return prefixType;
}
@@ -271,6 +285,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
out.writeEnum(prefixType);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.NLP_DOCUMENT_CHUNKING_ADDED)) {
out.writeBoolean(chunkResults);
}
}

@Override
@@ -285,7 +302,8 @@ public boolean equals(Object o) {
&& Objects.equals(objectsToInfer, that.objectsToInfer)
&& Objects.equals(textInput, that.textInput)
&& (highPriority == that.highPriority)
&& (prefixType == that.prefixType);
&& (prefixType == that.prefixType)
&& (chunkResults == that.chunkResults);
}

@Override
@@ -104,6 +104,7 @@ public static Request.Builder parseRequest(String id, XContentParser parser) {
// input and so cannot construct a document.
private final List<String> textInput;
private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
private boolean chunkResults = false;

public static Request forDocs(String id, InferenceConfigUpdate update, List<Map<String, Object>> docs, TimeValue inferenceTimeout) {
return new Request(
@@ -163,6 +164,11 @@ public Request(StreamInput in) throws IOException {
} else {
prefixType = TrainedModelPrefixStrings.PrefixType.NONE;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.NLP_DOCUMENT_CHUNKING_ADDED)) {
chunkResults = in.readBoolean();
} else {
chunkResults = false;
}
}

public String getId() {
@@ -215,6 +221,14 @@ public TrainedModelPrefixStrings.PrefixType getPrefixType() {
return prefixType;
}

public boolean isChunkResults() {
return chunkResults;
}

public void setChunkResults(boolean chunkResults) {
this.chunkResults = chunkResults;
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = super.validate();
@@ -244,6 +258,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TRAINED_MODEL_PREFIX_STRINGS_ADDED)) {
out.writeEnum(prefixType);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.NLP_DOCUMENT_CHUNKING_ADDED)) {
out.writeBoolean(chunkResults);
}
}

@Override
@@ -262,12 +279,13 @@ public boolean equals(Object o) {
&& Objects.equals(inferenceTimeout, that.inferenceTimeout)
&& Objects.equals(highPriority, that.highPriority)
&& Objects.equals(textInput, that.textInput)
&& (prefixType == that.prefixType);
&& (prefixType == that.prefixType)
&& (chunkResults == that.chunkResults);
}

@Override
public int hashCode() {
return Objects.hash(id, update, docs, inferenceTimeout, highPriority, textInput, prefixType);
return Objects.hash(id, update, docs, inferenceTimeout, highPriority, textInput, prefixType, chunkResults);
}

@Override
@@ -20,6 +20,8 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
@@ -671,6 +673,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
TextSimilarityInferenceResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, ChunkedTextEmbeddingResults.NAME, ChunkedTextEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, ChunkedTextExpansionResults.NAME, ChunkedTextExpansionResults::new)
);

// Inference Configs
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)
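
These registry entries are what let StreamInput.readNamedWriteableCollectionAsList resolve the two new chunked result types by name when a response crosses the wire. A minimal round-trip sketch under assumptions: provider stands for this provider class (its name is not visible in this diff) and response for any ChunkedInferenceAction.Response.

try (BytesStreamOutput out = new BytesStreamOutput()) {
    response.writeTo(out);
    var registry = new NamedWriteableRegistry(provider.getNamedWriteables());
    try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
        var roundTripped = new ChunkedInferenceAction.Response(in);
    }
}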