@@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -56,6 +57,8 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
@@ -315,6 +318,20 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
FillMaskConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
new ParseField(SlimConfig.NAME),
SlimConfig::fromXContentLenient
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
StrictlyParsedInferenceConfig.class,
new ParseField(SlimConfig.NAME),
SlimConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
@@ -436,6 +453,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
RegressionConfigUpdate::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
InferenceConfigUpdate.class,
new ParseField(SlimConfigUpdate.NAME),
SlimConfigUpdate::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
InferenceConfigUpdate.class,
@@ -588,6 +612,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, SlimResults.NAME, SlimResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(
@@ -619,6 +644,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, NerConfig.NAME, NerConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, FillMaskConfig.NAME, FillMaskConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, SlimConfig.NAME, SlimConfig::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, TextClassificationConfig.NAME, TextClassificationConfig::new)
);
@@ -658,6 +684,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME, ResultsFieldUpdate::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, SlimConfigUpdate.NAME, SlimConfigUpdate::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceConfigUpdate.class,
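Not part of the diff: a minimal sketch of how the registrations above are consumed, assuming the modified class is MlInferenceNamedXContentProvider (the provider whose getNamedXContentParsers()/getNamedWriteables() methods appear in the hunks; import paths are best-effort for this branch). Its entries feed a NamedWriteableRegistry, which is what lets a generic InferenceResults stream resolve the "slim_result" name back to SlimResults::new.

import java.util.List;

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; // assumed provider class
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;

public class SlimRegistrationSketch {
    public static void main(String[] args) throws Exception {
        // Build a registry from the entries added in this file.
        NamedWriteableRegistry registry = new NamedWriteableRegistry(
            new MlInferenceNamedXContentProvider().getNamedWriteables()
        );

        // Write a SlimResults under its registered name, "slim_result".
        BytesStreamOutput out = new BytesStreamOutput();
        out.writeNamedWriteable(
            new SlimResults("tokens", List.of(new SlimResults.WeightedToken(101, 0.5f)), false)
        );

        // Read it back generically; the registry resolves the name to SlimResults::new.
        try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
            InferenceResults roundTripped = in.readNamedWriteable(InferenceResults.class);
            assert roundTripped instanceof SlimResults;
        }
    }
}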
@@ -35,6 +35,10 @@ abstract class NlpInferenceResults implements InferenceResults {

abstract void addMapFields(Map<String, Object> map);

public boolean isTruncated() {
return isTruncated;
}

@Override
public final void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(isTruncated);
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class SlimResults extends NlpInferenceResults {

public static final String NAME = "slim_result";

public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject {

public static final String TOKEN = "token";
public static final String WEIGHT = "weight";

public WeightedToken(StreamInput in) throws IOException {
this(in.readVInt(), in.readFloat());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(token);
out.writeFloat(weight);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TOKEN, token);
builder.field(WEIGHT, weight);
builder.endObject();
return builder;
}

public Map<String, Object> asMap() {
return Map.of(TOKEN, token, WEIGHT, weight);
}

@Override
public String toString() {
return Strings.toString(this);
}
}

private final String resultsField;
private final List<WeightedToken> weightedTokens;

public SlimResults(String resultField, List<WeightedToken> weightedTokens, boolean isTruncated) {
super(isTruncated);
this.resultsField = resultField;
this.weightedTokens = weightedTokens;
}

public SlimResults(StreamInput in) throws IOException {
super(in);
this.resultsField = in.readString();
this.weightedTokens = in.readList(WeightedToken::new);
}

public List<WeightedToken> getWeightedTokens() {
return weightedTokens;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public String getResultsField() {
return resultsField;
}

@Override
public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
}

@Override
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.startArray(resultsField);
for (var weightedToken : weightedTokens) {
weightedToken.toXContent(builder, params);
}
builder.endArray();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
SlimResults that = (SlimResults) o;
return Objects.equals(resultsField, that.resultsField) && Objects.equals(weightedTokens, that.weightedTokens);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), resultsField, weightedTokens);
}

@Override
void doWriteTo(StreamOutput out) throws IOException {
out.writeString(resultsField);
out.writeList(weightedTokens);
}

@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, weightedTokens.stream().map(WeightedToken::asMap).collect(Collectors.toList()));
}
}
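A short usage sketch (not in the PR) of the document shape SlimResults produces: doXContentBody writes the weighted tokens as an array under the results field. The field name "tokens" and the token IDs are illustrative, and the exact rendering of the inherited is_truncated flag is an assumption about the superclass.

import java.util.List;

import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;

public class SlimResultsShapeSketch {
    public static void main(String[] args) {
        SlimResults results = new SlimResults(
            "tokens",
            List.of(new SlimResults.WeightedToken(101, 1.25f), new SlimResults.WeightedToken(2045, 0.67f)),
            false
        );
        // Renders roughly as:
        // {"tokens":[{"token":101,"weight":1.25},{"token":2045,"weight":0.67}],"is_truncated":false}
        System.out.println(Strings.toString(results));
    }
}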
@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class SlimConfig implements NlpConfig {

public static final String NAME = "slim";

public static SlimConfig fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}

public static SlimConfig fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}

private static final ConstructingObjectParser<SlimConfig, Void> STRICT_PARSER = createParser(false);
private static final ConstructingObjectParser<SlimConfig, Void> LENIENT_PARSER = createParser(true);

private static ConstructingObjectParser<SlimConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<SlimConfig, Void> parser = new ConstructingObjectParser<>(
NAME,
ignoreUnknownFields,
a -> new SlimConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2])
);
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
}, VOCABULARY);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
return parser;
}

private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;
private final String resultsField;

public SlimConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable String resultsField) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
if (this.tokenization instanceof BertTokenization == false) {
throw ExceptionsHelper.badRequestException(
"SLIM must be configured with BERT tokenizer, [{}] given",
this.tokenization.getName()
);
}
// TODO support spanning
if (this.tokenization.span != -1) {
throw ExceptionsHelper.badRequestException(
"[{}] does not support windowing long text sequences; configured span [{}]",
NAME,
this.tokenization.span
);
}
this.resultsField = resultsField;
}

public SlimConfig(StreamInput in) throws IOException {
vocabularyConfig = new VocabularyConfig(in);
tokenization = in.readNamedWriteable(Tokenization.class);
resultsField = in.readOptionalString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
vocabularyConfig.writeTo(out);
out.writeNamedWriteable(tokenization);
out.writeOptionalString(resultsField);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public boolean isTargetTypeSupported(TargetType targetType) {
// TargetType relates to boosted tree models
return false;
}

@Override
public boolean isAllocateOnly() {
return true;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.V_8_7_0;
}

@Override
public String getResultsField() {
return resultsField;
}

@Override
public VocabularyConfig getVocabularyConfig() {
return vocabularyConfig;
}

@Override
public Tokenization getTokenization() {
return tokenization;
}

@Override
public String getName() {
return NAME;
}

@Override
public String toString() {
return Strings.toString(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SlimConfig that = (SlimConfig) o;
return Objects.equals(vocabularyConfig, that.vocabularyConfig)
&& Objects.equals(tokenization, that.tokenization)
&& Objects.equals(resultsField, that.resultsField);
}

@Override
public int hashCode() {
return Objects.hash(vocabularyConfig, tokenization, resultsField);
}
}
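And a sketch (again, not part of the change) of the constructor guards above: nulls are filled in with defaults, a non-BERT tokenization is rejected, and so is any BERT tokenization with a configured span, since spanning is still a TODO. The BertTokenization constructor arguments here are an assumption about this branch.

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

public class SlimConfigGuardsSketch {
    public static void main(String[] args) {
        // Nulls are fine: vocabulary falls back to the native definition store,
        // tokenization to Tokenization.createDefault(), the results field stays null.
        SlimConfig defaults = new SlimConfig(null, null, null);
        System.out.println(defaults.getName()); // "slim"

        // A span other than -1 trips the "does not support windowing" guard.
        Tokenization spanning = new BertTokenization(null, null, 512, Tokenization.Truncate.FIRST, 128);
        try {
            new SlimConfig(null, spanning, "tokens");
        } catch (Exception e) {
            // "[slim] does not support windowing long text sequences; configured span [128]"
        }
    }
}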