Skip to content

Commit

Permalink
[ML] update truncation default & adding field output when input is tr…
Browse files Browse the repository at this point in the history
…uncated (#79942) (#80022)

This commit makes the two following changes (along with some
refactoring)  - Nlp results will now indicate if the input was truncated
or not  - The default truncation is now `none` instead of `first`
  • Loading branch information
benwtrent committed Oct 28, 2021
1 parent 8674678 commit 44dc322
Show file tree
Hide file tree
Showing 25 changed files with 429 additions and 113 deletions.
2 changes: 1 addition & 1 deletion docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ end::inference-config-nlp-tokenization-bert-do-lower-case[]

tag::inference-config-nlp-tokenization-bert-truncate[]
Indicates how tokens are truncated when they exceed `max_sequence_length`.
The default value is `first`.
The default value is `none`.
+
--
* `none`: No truncation occurs; the inference request receives an error.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
Expand Down Expand Up @@ -498,7 +499,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceResults.class,
NlpClassificationInferenceResults.NAME,
NlpClassificationInferenceResults::new
)
);
// Inference Configs
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,27 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class FillMaskResults extends ClassificationInferenceResults {
public class FillMaskResults extends NlpClassificationInferenceResults {

public static final String NAME = "fill_mask_result";

private final String predictedSequence;

public FillMaskResults(
double value,
String classificationLabel,
String predictedSequence,
List<TopClassEntry> topClasses,
String topNumClassesField,
String resultsField,
Double predictionProbability
Double predictionProbability,
boolean isTruncated
) {
super(
value,
classificationLabel,
topClasses,
List.of(),
topNumClassesField,
resultsField,
PredictionFieldType.STRING,
0,
predictionProbability,
null
);
super(classificationLabel, topClasses, resultsField, predictionProbability, isTruncated);
this.predictedSequence = predictedSequence;
}

Expand All @@ -54,8 +40,8 @@ public FillMaskResults(StreamInput in) throws IOException {
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
public void doWriteTo(StreamOutput out) throws IOException {
super.doWriteTo(out);
out.writeString(predictedSequence);
}

Expand All @@ -64,11 +50,9 @@ public String getPredictedSequence() {
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
void addMapFields(Map<String, Object> map) {
super.addMapFields(map);
map.put(resultsField + "_sequence", predictedSequence);
map.putAll(super.asMap());
return map;
}

@Override
Expand All @@ -77,8 +61,9 @@ public String getWriteableName() {
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence);
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
super.doXContentBody(builder, params);
builder.field(resultsField + "_sequence", predictedSequence);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.Objects;
import java.util.stream.Collectors;

public class NerResults implements InferenceResults {
public class NerResults extends NlpInferenceResults {

public static final String NAME = "ner_result";
public static final String ENTITY_FIELD = "entities";
Expand All @@ -30,27 +30,28 @@ public class NerResults implements InferenceResults {

private final List<EntityGroup> entityGroups;

public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups) {
public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups, boolean isTruncated) {
super(isTruncated);
this.entityGroups = Objects.requireNonNull(entityGroups);
this.resultsField = Objects.requireNonNull(resultsField);
this.annotatedResult = Objects.requireNonNull(annotatedResult);
}

public NerResults(StreamInput in) throws IOException {
super(in);
entityGroups = in.readList(EntityGroup::new);
resultsField = in.readString();
annotatedResult = in.readString();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, annotatedResult);
builder.startArray("entities");
for (EntityGroup entity : entityGroups) {
entity.toXContent(builder, params);
}
builder.endArray();
return builder;
}

@Override
Expand All @@ -59,18 +60,16 @@ public String getWriteableName() {
}

@Override
public void writeTo(StreamOutput out) throws IOException {
void doWriteTo(StreamOutput out) throws IOException {
out.writeList(entityGroups);
out.writeString(resultsField);
out.writeString(annotatedResult);
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
void addMapFields(Map<String, Object> map) {
map.put(resultsField, annotatedResult);
map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
return map;
}

@Override
Expand All @@ -95,15 +94,16 @@ public String getAnnotatedResult() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
NerResults that = (NerResults) o;
return Objects.equals(entityGroups, that.entityGroups)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(annotatedResult, that.annotatedResult);
return Objects.equals(resultsField, that.resultsField)
&& Objects.equals(annotatedResult, that.annotatedResult)
&& Objects.equals(entityGroups, that.entityGroups);
}

@Override
public int hashCode() {
return Objects.hash(entityGroups, resultsField, annotatedResult);
return Objects.hash(super.hashCode(), resultsField, annotatedResult, entityGroups);
}

public static class EntityGroup implements ToXContentObject, Writeable {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class NlpClassificationInferenceResults extends NlpInferenceResults {

public static final String NAME = "nlp_classification";

// Accessed in sub-classes
protected final String resultsField;
private final String classificationLabel;
private final Double predictionProbability;
private final List<TopClassEntry> topClasses;

public NlpClassificationInferenceResults(
String classificationLabel,
List<TopClassEntry> topClasses,
String resultsField,
Double predictionProbability,
boolean isTruncated
) {
super(isTruncated);
this.classificationLabel = Objects.requireNonNull(classificationLabel);
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.resultsField = resultsField;
this.predictionProbability = predictionProbability;
}

public NlpClassificationInferenceResults(StreamInput in) throws IOException {
super(in);
this.classificationLabel = in.readString();
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
this.resultsField = in.readString();
this.predictionProbability = in.readOptionalDouble();
}

public String getClassificationLabel() {
return classificationLabel;
}

public List<TopClassEntry> getTopClasses() {
return topClasses;
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeString(classificationLabel);
out.writeCollection(topClasses);
out.writeString(resultsField);
out.writeOptionalDouble(predictionProbability);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
NlpClassificationInferenceResults that = (NlpClassificationInferenceResults) o;
return Objects.equals(resultsField, that.resultsField)
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(predictionProbability, that.predictionProbability)
&& Objects.equals(topClasses, that.topClasses);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), resultsField, classificationLabel, predictionProbability, topClasses);
}

public Double getPredictionProbability() {
return predictionProbability;
}

@Override
public String getResultsField() {
return resultsField;
}

@Override
public Object predictedValue() {
return classificationLabel;
}

@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, classificationLabel);
if (topClasses.isEmpty() == false) {
map.put(
NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())
);
}
if (predictionProbability != null) {
map.put(PREDICTION_PROBABILITY, predictionProbability);
}
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, classificationLabel);
if (topClasses.size() > 0) {
builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
}
if (predictionProbability != null) {
builder.field(PREDICTION_PROBABILITY, predictionProbability);
}
}
}

0 comments on commit 44dc322

Please sign in to comment.