diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
index bc0ae34aed0cd..df6d564596c0d 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -56,6 +57,8 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenizationUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
@@ -315,6 +318,20 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
                 FillMaskConfig::fromXContentStrict
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                LenientlyParsedInferenceConfig.class,
+                new ParseField(SlimConfig.NAME),
+                SlimConfig::fromXContentLenient
+            )
+        );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                StrictlyParsedInferenceConfig.class,
+                new ParseField(SlimConfig.NAME),
+                SlimConfig::fromXContentStrict
+            )
+        );
         namedXContent.add(
             new NamedXContentRegistry.Entry(
                 LenientlyParsedInferenceConfig.class,
@@ -436,6 +453,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
                 RegressionConfigUpdate::fromXContentStrict
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                InferenceConfigUpdate.class,
+                new ParseField(SlimConfigUpdate.NAME),
+                SlimConfigUpdate::fromXContentStrict
+            )
+        );
         namedXContent.add(
             new NamedXContentRegistry.Entry(
                 InferenceConfigUpdate.class,
@@ -588,6 +612,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
         );
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, SlimResults.NAME, SlimResults::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
@@ -619,6 +644,7 @@
         );
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, NerConfig.NAME, NerConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, FillMaskConfig.NAME, FillMaskConfig::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, SlimConfig.NAME, SlimConfig::new));
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(InferenceConfig.class, TextClassificationConfig.NAME, TextClassificationConfig::new)
         );
@@ -658,6 +684,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME, ResultsFieldUpdate::new)
         );
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, SlimConfigUpdate.NAME, SlimConfigUpdate::new));
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
                 InferenceConfigUpdate.class,
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java
index b11efe8ee7ab6..33ef519f5e41c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java
@@ -35,6 +35,10 @@ abstract class NlpInferenceResults implements InferenceResults {
 
     abstract void addMapFields(Map<String, Object> map);
 
+    public boolean isTruncated() {
+        return isTruncated;
+    }
+
     @Override
     public final void writeTo(StreamOutput out) throws IOException {
         out.writeBoolean(isTruncated);
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResults.java
new file mode 100644
index 0000000000000..beb4e2c790efd
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResults.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+public class SlimResults extends NlpInferenceResults {
+
+    public static final String NAME = "slim_result";
+
+    public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject {
+
+        public static final String TOKEN = "token";
+        public static final String WEIGHT = "weight";
+
+        public WeightedToken(StreamInput in) throws IOException {
+            this(in.readVInt(), in.readFloat());
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVInt(token);
+            out.writeFloat(weight);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(TOKEN, token);
+            builder.field(WEIGHT, weight);
+            builder.endObject();
+            return builder;
+        }
+
+        public Map<String, Object> asMap() {
+            return Map.of(TOKEN, token, WEIGHT, weight);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+
+    private final String resultsField;
+    private final List<WeightedToken> weightedTokens;
+
+    public SlimResults(String resultField, List<WeightedToken> weightedTokens, boolean isTruncated) {
+        super(isTruncated);
+        this.resultsField = resultField;
+        this.weightedTokens = weightedTokens;
+    }
+
+    public SlimResults(StreamInput in) throws IOException {
+        super(in);
+        this.resultsField = in.readString();
+        this.weightedTokens = in.readList(WeightedToken::new);
+    }
+
+    public List<WeightedToken> getWeightedTokens() {
+        return weightedTokens;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public Object predictedValue() {
+        throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
+    }
+
+    @Override
+    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        builder.startArray(resultsField);
+        for (var weightedToken : weightedTokens) {
+            weightedToken.toXContent(builder, params);
+        }
+        builder.endArray();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        SlimResults that = (SlimResults) o;
+        return Objects.equals(resultsField, that.resultsField) && Objects.equals(weightedTokens, that.weightedTokens);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), resultsField, weightedTokens);
+    }
+
+    @Override
+    void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(resultsField);
+        out.writeList(weightedTokens);
+    }
+
+    @Override
+    void addMapFields(Map<String, Object> map) {
+        map.put(resultsField, weightedTokens.stream().map(WeightedToken::asMap).collect(Collectors.toList()));
+    }
+}
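A minimal usage sketch of the result class above (editorial illustration, not part of the diff; the field name "slim_tokens" and the token IDs are arbitrary examples):

    // Two vocabulary tokens with their predicted weights; the input was not truncated.
    SlimResults results = new SlimResults(
        "slim_tokens",
        List.of(new SlimResults.WeightedToken(101, 1.8f), new SlimResults.WeightedToken(2041, 0.3f)),
        false
    );
    // doXContentBody/addMapFields render this under the results field as:
    //   "slim_tokens": [ { "token": 101, "weight": 1.8 }, { "token": 2041, "weight": 0.3 } ]
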
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfig.java
new file mode 100644
index 0000000000000..a3c6522ef51eb
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfig.java
@@ -0,0 +1,175 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.Optional;
+
+public class SlimConfig implements NlpConfig {
+
+    public static final String NAME = "slim";
+
+    public static SlimConfig fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static SlimConfig fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    private static final ConstructingObjectParser<SlimConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<SlimConfig, Void> LENIENT_PARSER = createParser(true);
+
+    private static ConstructingObjectParser<SlimConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<SlimConfig, Void> parser = new ConstructingObjectParser<>(
+            NAME,
+            ignoreUnknownFields,
+            a -> new SlimConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2])
+        );
+        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
+            if (ignoreUnknownFields == false) {
+                throw ExceptionsHelper.badRequestException(
+                    "illegal setting [{}] on inference model creation",
+                    VOCABULARY.getPreferredName()
+                );
+            }
+            return VocabularyConfig.fromXContentLenient(p);
+        }, VOCABULARY);
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+            TOKENIZATION
+        );
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
+        return parser;
+    }
+
+    private final VocabularyConfig vocabularyConfig;
+    private final Tokenization tokenization;
+    private final String resultsField;
+
+    public SlimConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable String resultsField) {
+        this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
+            .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        if (this.tokenization instanceof BertTokenization == false) {
+            throw ExceptionsHelper.badRequestException(
+                "SLIM must be configured with BERT tokenizer, [{}] given",
+                this.tokenization.getName()
+            );
+        }
+        // TODO support spanning
+        if (this.tokenization.span != -1) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] does not support windowing long text sequences; configured span [{}]",
+                NAME,
+                this.tokenization.span
+            );
+        }
+        this.resultsField = resultsField;
+    }
+
+    public SlimConfig(StreamInput in) throws IOException {
+        vocabularyConfig = new VocabularyConfig(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
+        resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        vocabularyConfig.writeTo(out);
+        out.writeNamedWriteable(tokenization);
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        // TargetType relates to boosted tree models
+        return false;
+    }
+
+    @Override
+    public boolean isAllocateOnly() {
+        return true;
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return Version.V_8_7_0;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public VocabularyConfig getVocabularyConfig() {
+        return vocabularyConfig;
+    }
+
+    @Override
+    public Tokenization getTokenization() {
+        return tokenization;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SlimConfig that = (SlimConfig) o;
+        return Objects.equals(vocabularyConfig, that.vocabularyConfig)
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(vocabularyConfig, tokenization, resultsField);
+    }
+}
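The constructor above accepts null for every argument, falling back to the native-definition vocabulary store, default BERT tokenization and no results field, but it rejects non-BERT tokenizers and any span other than -1. A sketch of that behaviour (editorial illustration; the tokenizer settings shown are assumptions, using the BertTokenization(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span) constructor exercised in SlimConfigTests):

    // Accepted: all defaults.
    SlimConfig defaults = new SlimConfig(null, null, null);

    // Rejected with a 400 response: SLIM requires BERT tokenization.
    //   new SlimConfig(null, RobertaTokenizationTests.createRandom(), null);

    // Rejected with a 400 response: spanning long sequences is not supported yet, so span must stay -1.
    //   new SlimConfig(null, new BertTokenization(null, null, 512, Tokenization.Truncate.FIRST, 128), null);
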
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdate.java
new file mode 100644
index 0000000000000..5495e02913923
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdate.java
@@ -0,0 +1,180 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
+
+public class SlimConfigUpdate extends NlpConfigUpdate {
+
+    public static final String NAME = SlimConfig.NAME;
+
+    public static SlimConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new SlimConfigUpdate(resultsField, tokenizationUpdate);
+    }
+
+    private static final ObjectParser<SlimConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<SlimConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<SlimConfigUpdate.Builder, Void> parser = new ObjectParser<>(NAME, lenient, SlimConfigUpdate.Builder::new);
+        parser.declareString(SlimConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        parser.declareNamedObject(
+            SlimConfigUpdate.Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
+        return parser;
+    }
+
+    public static SlimConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final String resultsField;
+
+    public SlimConfigUpdate(String resultsField, TokenizationUpdate tokenizationUpdate) {
+        super(tokenizationUpdate);
+        this.resultsField = resultsField;
+    }
+
+    public SlimConfigUpdate(StreamInput in) throws IOException {
+        super(in);
+        this.resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof SlimConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName()
+            );
+        }
+        SlimConfig slimConfig = (SlimConfig) originalConfig;
+        if (isNoop(slimConfig)) {
+            return slimConfig;
+        }
+
+        return new SlimConfig(
+            slimConfig.getVocabularyConfig(),
+            (tokenizationUpdate == null) ? slimConfig.getTokenization() : tokenizationUpdate.apply(slimConfig.getTokenization()),
+            Optional.ofNullable(resultsField).orElse(slimConfig.getResultsField())
+        );
+    }
+
+    boolean isNoop(SlimConfig originalConfig) {
+        return (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField())) && super.isNoop();
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof SlimConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new SlimConfigUpdate.Builder().setResultsField(resultsField).setTokenizationUpdate(tokenizationUpdate);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SlimConfigUpdate that = (SlimConfigUpdate) o;
+        return Objects.equals(resultsField, that.resultsField) && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resultsField, tokenizationUpdate);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return Version.V_8_7_0;
+    }
+
+    public static class Builder implements InferenceConfigUpdate.Builder<SlimConfigUpdate.Builder, SlimConfigUpdate> {
+        private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
+
+        @Override
+        public SlimConfigUpdate.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public SlimConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
+        @Override
+        public SlimConfigUpdate build() {
+            return new SlimConfigUpdate(resultsField, tokenizationUpdate);
+        }
+    }
+
+}
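A sketch of how the update above is typically built from an options map (editorial illustration; "my_tokens" is an arbitrary value, and the keys come from NlpConfig.RESULTS_FIELD and NlpConfig.TOKENIZATION):

    Map<String, Object> options = new HashMap<>();
    options.put("results_field", "my_tokens");
    SlimConfigUpdate update = SlimConfigUpdate.fromMap(options);
    // Any leftover key is rejected: fromMap throws a 400 "Unrecognized fields {...}" once
    // results_field and the optional tokenization update have been consumed.
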
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java
index c9f978cc2bb6b..c79dd4f3734d3 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java
@@ -27,6 +27,8 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdateTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdateTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdateTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
@@ -113,6 +115,8 @@ protected Request mutateInstanceForVersion(Request instance, Version version) {
             adjustedUpdate = PassThroughConfigUpdateTests.mutateForVersion(update, version);
         } else if (nlpConfigUpdate instanceof QuestionAnsweringConfigUpdate update) {
             adjustedUpdate = QuestionAnsweringConfigUpdateTests.mutateForVersion(update, version);
+        } else if (nlpConfigUpdate instanceof
SlimConfigUpdate update) { + adjustedUpdate = SlimConfigUpdateTests.mutateForVersion(update, version); } else { throw new IllegalArgumentException("Unknown update [" + currentUpdate.getName() + "]"); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResultsTests.java new file mode 100644 index 0000000000000..a83d59ea97622 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResultsTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class SlimResultsTests extends InferenceResultsTestCase { + @Override + protected Writeable.Reader instanceReader() { + return SlimResults::new; + } + + @Override + protected SlimResults createTestInstance() { + int numTokens = randomIntBetween(0, 20); + List tokenList = new ArrayList<>(); + for (int i = 0; i < numTokens; i++) { + tokenList.add(new SlimResults.WeightedToken(i, (float) randomDoubleBetween(0.0, 5.0, false))); + } + return new SlimResults(randomAlphaOfLength(4), tokenList, randomBoolean()); + } + + @Override + @SuppressWarnings("unchecked") + void assertFieldValues(SlimResults createdInstance, IngestDocument document, String resultsField) { + var ingestedTokens = (List>) document.getFieldValue( + resultsField + '.' 
+ createdInstance.getResultsField(), + List.class + ); + var originalTokens = createdInstance.getWeightedTokens(); + assertEquals(originalTokens.size(), ingestedTokens.size()); + for (int i = 0; i < createdInstance.getWeightedTokens().size(); i++) { + assertEquals(originalTokens.get(i).token(), (int) ingestedTokens.get(i).get("token")); + assertEquals(originalTokens.get(i).weight(), (float) ingestedTokens.get(i).get("weight"), 0.0001); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractNlpConfigUpdateTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractNlpConfigUpdateTestCase.java index c58146368252d..dc3335d8407ff 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractNlpConfigUpdateTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractNlpConfigUpdateTestCase.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.core.Tuple; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -29,6 +30,16 @@ protected NamedWriteableRegistry writableRegistry() { return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); } + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } + /** * @param expectedTokenization The tokenization update that will be provided * @return A map and expected resulting object. 
Note: `tokenization` will be overwritten if provided in the returned map diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java index ed174655e6c99..beed7b1756c1e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -131,14 +128,4 @@ protected FillMaskConfigUpdate createTestInstance() { protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) { return mutateForVersion(instance, version); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java index d1089330287fc..095a6b1f1d4d8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -111,14 +108,4 @@ protected NerConfigUpdate createTestInstance() { protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) { return mutateForVersion(instance, version); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java index 002b9da8a5223..d5871fc19acdb 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -101,14 +98,4 @@ protected PassThroughConfigUpdate createTestInstance() { protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) { return mutateForVersion(instance, version); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java index c6786b4e4d097..9bb107d93967a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -171,14 +168,4 @@ public void testApply() { public static QuestionAnsweringConfigUpdate createRandom() { return randomUpdate(); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigTests.java new file mode 100644 index 0000000000000..edfa63df4e4ea --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigTests.java @@ -0,0 +1,66 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; + +import java.io.IOException; + +public class SlimConfigTests extends InferenceConfigItemTestCase { + + public static SlimConfig createRandom() { + // create a tokenization config with a no span setting. + var tokenization = new BertTokenization( + randomBoolean() ? null : randomBoolean(), + randomBoolean() ? null : randomBoolean(), + randomBoolean() ? null : randomIntBetween(1, 1024), + randomBoolean() ? null : randomFrom(Tokenization.Truncate.FIRST), + null + ); + + return new SlimConfig( + randomBoolean() ? null : VocabularyConfigTests.createRandom(), + randomBoolean() ? null : tokenization, + randomBoolean() ? null : randomAlphaOfLength(5) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return SlimConfig::new; + } + + @Override + protected SlimConfig createTestInstance() { + return createRandom(); + } + + @Override + protected SlimConfig doParseInstance(XContentParser parser) throws IOException { + return SlimConfig.fromXContentLenient(parser); + } + + @Override + protected SlimConfig mutateInstanceForVersion(SlimConfig instance, Version version) { + return instance; + } + + public void testBertTokenizationOnly() { + ElasticsearchStatusException e = expectThrows( + ElasticsearchStatusException.class, + () -> new SlimConfig(null, RobertaTokenizationTests.createRandom(), null) + ); + assertEquals(RestStatus.BAD_REQUEST, e.status()); + assertEquals("SLIM must be configured with BERT tokenizer, [roberta] given", e.getMessage()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdateTests.java new file mode 100644 index 0000000000000..96ad9d6ddf142 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfigUpdateTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class SlimConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + + public static SlimConfigUpdate randomUpdate() { + SlimConfigUpdate.Builder builder = new SlimConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static SlimConfigUpdate mutateForVersion(SlimConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new SlimConfigUpdate(instance.getResultsField(), null); + } + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return SlimConfigUpdate::new; + } + + @Override + protected SlimConfigUpdate createTestInstance() { + return randomUpdate(); + } + + @Override + protected SlimConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return SlimConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected SlimConfigUpdate mutateInstanceForVersion(SlimConfigUpdate instance, Version version) { + return mutateForVersion(instance, version); + } + + @Override + Tuple, SlimConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { + SlimConfigUpdate expected = new SlimConfigUpdate("ml-results", expectedTokenization); + Map config = new HashMap<>() { + { + put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results"); + } + }; + return Tuple.tuple(config, expected); + } + + @Override + SlimConfigUpdate fromMap(Map map) { + return SlimConfigUpdate.fromMap(map); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java index a8c2c871e488b..d4bdc8ca92f26 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java @@ -9,12 +9,9 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -195,14 +192,4 @@ protected TextClassificationConfigUpdate createTestInstance() { protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) { return mutateForVersion(instance, version); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry 
getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java index 987722e291afe..6c9377b8d46a1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -101,14 +98,4 @@ protected TextEmbeddingConfigUpdate createTestInstance() { protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) { return mutateForVersion(instance, version); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigUpdateTests.java index 5c8d21ae99fe2..5972c2a902da7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.Arrays; @@ -169,14 +166,4 @@ public void testApply() { public static TextSimilarityConfigUpdate createRandom() { return randomUpdate(); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java index 2d424edac4c94..9a4c5d8a46ecb 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; import java.util.HashMap; @@ -212,14 +209,4 @@ public void testIsNoop() { public static ZeroShotClassificationConfigUpdate createRandom() { return randomUpdate(); } - - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 18685aafa6a6c..873f24457a527 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -41,6 +41,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; @@ -306,6 +308,9 @@ InferenceConfigUpdate inferenceConfigUpdateFromMap(Map configMap } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); return RegressionConfigUpdate.fromMap(valueMap); + } else if (configMap.containsKey(SlimConfig.NAME)) { + checkNlpSupported(SlimConfig.NAME); + return SlimConfigUpdate.fromMap(valueMap); } else if (configMap.containsKey(TextClassificationConfig.NAME)) { checkNlpSupported(TextClassificationConfig.NAME); return TextClassificationConfigUpdate.fromMap(valueMap); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessor.java new file mode 100644 index 0000000000000..7edb052ba18af --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessor.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
+
+public class SlimProcessor extends NlpTask.Processor {
+
+    private final NlpTask.RequestBuilder requestBuilder;
+
+    public SlimProcessor(NlpTokenizer tokenizer) {
+        super(tokenizer);
+        this.requestBuilder = tokenizer.requestBuilder();
+    }
+
+    @Override
+    public void validateInputs(List<String> inputs) {}
+
+    @Override
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
+        return requestBuilder;
+    }
+
+    @Override
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
+        return (tokenization, pyTorchResult) -> processResult(tokenization, pyTorchResult, config.getResultsField());
+    }
+
+    static InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult, String resultsField) {
+        // Convert the verbose results to the sparse format.
+        // Anything with a score > 0.0 is retained.
+        List<SlimResults.WeightedToken> weightedTokens = new ArrayList<>();
+        double[] weights = pyTorchResult.getInferenceResult()[0][0];
+        for (int i = 0; i < weights.length; i++) {
+            if (weights[i] > 0.0) {
+                weightedTokens.add(new SlimResults.WeightedToken(i, (float) weights[i]));
+            }
+        }
+
+        return new SlimResults(
+            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
+            weightedTokens,
+            tokenization.anyTruncated()
+        );
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
index a1fca461a7381..4f86e10e9d7cf 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
@@ -64,6 +64,12 @@ public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new TextSimilarityProcessor(tokenizer);
         }
+    },
+    SLIM {
+        @Override
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
+            return new SlimProcessor(tokenizer);
+        }
     };
 
     public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessorTests.java
new file mode 100644
index 0000000000000..ad07b25b65415
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessorTests.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+
+public class SlimProcessorTests extends ESTestCase {
+
+    public void testProcessResult() {
+        double[][][] pytorchResult = new double[][][] { { { 0.0, 1.0, 0.0, 3.0, 4.0, 0.0, 0.0 } } };
+
+        TokenizationResult tokenizationResult = new BertTokenizationResult(List.of(), List.of(), 0);
+
+        var inferenceResult = SlimProcessor.processResult(tokenizationResult, new PyTorchInferenceResult(pytorchResult), "foo");
+        assertThat(inferenceResult, instanceOf(SlimResults.class));
+        var slimResults = (SlimResults) inferenceResult;
+        assertEquals(slimResults.getResultsField(), "foo");
+
+        var weightedTokens = slimResults.getWeightedTokens();
+        assertThat(weightedTokens, hasSize(3));
+        assertEquals(new SlimResults.WeightedToken(1, 1.0f), weightedTokens.get(0));
+        assertEquals(new SlimResults.WeightedToken(3, 3.0f), weightedTokens.get(1));
+        assertEquals(new SlimResults.WeightedToken(4, 4.0f), weightedTokens.get(2));
+    }
+}
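How the pieces fit together (an editorial sketch, not part of the diff): an ingest inference processor whose inference_config contains a "slim" entry is routed by InferenceProcessor.inferenceConfigUpdateFromMap through checkNlpSupported(SlimConfig.NAME) to SlimConfigUpdate.fromMap, the SLIM task type creates a SlimProcessor, and processResult collapses the dense per-token weights returned by the PyTorch process into the sparse SlimResults form. Using the fixture from testProcessResult above:

    // Dense output for one sequence, indexed by vocabulary token ID.
    double[] weights = { 0.0, 1.0, 0.0, 3.0, 4.0, 0.0, 0.0 };
    // Only strictly positive weights are kept, giving the sparse representation:
    //   [ { "token": 1, "weight": 1.0 }, { "token": 3, "weight": 3.0 }, { "token": 4, "weight": 4.0 } ]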