Skip to content

Commit

Permalink
[ML] Use token strings rather than IDs in text expansion (#94389)
Browse files Browse the repository at this point in the history
Use the token name rather than numerical ID in the text expansion model results
  • Loading branch information
davidkyle committed Mar 28, 2023
1 parent 65492f0 commit 8240641
Show file tree
Hide file tree
Showing 16 changed files with 378 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;

Expand Down Expand Up @@ -362,7 +361,14 @@ public static Builder builder() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("inference_results", inferenceResults.stream().map(InferenceResults::asMap).collect(Collectors.toList()));
builder.startArray("inference_results");
for (var inference : inferenceResults) {
// inference results implement ToXContentFragment
builder.startObject();
inference.toXContent(builder, params);
builder.endObject();
}
builder.endArray();
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand All @@ -24,28 +24,26 @@ public class TextExpansionResults extends NlpInferenceResults {

public static final String NAME = "text_expansion_result";

public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject {
public record WeightedToken(String token, float weight) implements Writeable, ToXContentFragment {

public WeightedToken(StreamInput in) throws IOException {
this(in.readVInt(), in.readFloat());
this(in.readString(), in.readFloat());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(token);
out.writeString(token);
out.writeFloat(weight);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Integer.toString(token), weight);
builder.endObject();
builder.field(token, weight);
return builder;
}

public Map<String, Object> asMap() {
return Map.of(Integer.toString(token), weight);
return Map.of(token, weight);
}

@Override
Expand Down Expand Up @@ -90,11 +88,11 @@ public Object predictedValue() {

@Override
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.startArray(resultsField);
builder.startObject(resultsField);
for (var weightedToken : weightedTokens) {
weightedToken.toXContent(builder, params);
}
builder.endArray();
builder.endObject();
}

@Override
Expand All @@ -119,6 +117,6 @@ void doWriteTo(StreamOutput out) throws IOException {

@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, weightedTokens.stream().map(WeightedToken::asMap).collect(Collectors.toList()));
map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
Expand All @@ -26,26 +27,36 @@
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResultsTests;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.containsString;

public class InferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {

private static List<String> INFERENCE_RESULT_TYPES = List.of(
ClassificationInferenceResults.NAME,
RegressionInferenceResults.NAME,
NerResults.NAME,
TextEmbeddingResults.NAME,
PyTorchPassThroughResults.NAME,
FillMaskResults.NAME,
WarningInferenceResults.NAME,
QuestionAnsweringInferenceResults.NAME,
TextExpansionResults.NAME
);

@Override
protected Response createTestInstance() {
String resultType = randomFrom(
ClassificationInferenceResults.NAME,
RegressionInferenceResults.NAME,
NerResults.NAME,
TextEmbeddingResults.NAME,
PyTorchPassThroughResults.NAME,
FillMaskResults.NAME,
WarningInferenceResults.NAME,
QuestionAnsweringInferenceResults.NAME
);
String resultType = randomFrom(INFERENCE_RESULT_TYPES);

return new Response(
Stream.generate(() -> randomInferenceResult(resultType)).limit(randomIntBetween(0, 10)).collect(Collectors.toList()),
randomAlphaOfLength(10),
Expand All @@ -55,30 +66,37 @@ protected Response createTestInstance() {

@Override
protected Response mutateInstance(Response instance) {
return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
var modelId = instance.getModelId();
var isLicensed = instance.isLicensed();
if (randomBoolean()) {
modelId = modelId + "foo";
} else {
isLicensed = isLicensed == false;
}
return new Response(instance.getInferenceResults(), modelId, isLicensed);
}

private static InferenceResults randomInferenceResult(String resultType) {
switch (resultType) {
case ClassificationInferenceResults.NAME:
return ClassificationInferenceResultsTests.createRandomResults();
case RegressionInferenceResults.NAME:
return RegressionInferenceResultsTests.createRandomResults();
case NerResults.NAME:
return NerResultsTests.createRandomResults();
case TextEmbeddingResults.NAME:
return TextEmbeddingResultsTests.createRandomResults();
case PyTorchPassThroughResults.NAME:
return PyTorchPassThroughResultsTests.createRandomResults();
case FillMaskResults.NAME:
return FillMaskResultsTests.createRandomResults();
case WarningInferenceResults.NAME:
return WarningInferenceResultsTests.createRandomResults();
case QuestionAnsweringInferenceResults.NAME:
return QuestionAnsweringInferenceResultsTests.createRandomResults();
default:
fail("unexpected result type [" + resultType + "]");
return null;
return switch (resultType) {
case ClassificationInferenceResults.NAME -> ClassificationInferenceResultsTests.createRandomResults();
case RegressionInferenceResults.NAME -> RegressionInferenceResultsTests.createRandomResults();
case NerResults.NAME -> NerResultsTests.createRandomResults();
case TextEmbeddingResults.NAME -> TextEmbeddingResultsTests.createRandomResults();
case PyTorchPassThroughResults.NAME -> PyTorchPassThroughResultsTests.createRandomResults();
case FillMaskResults.NAME -> FillMaskResultsTests.createRandomResults();
case WarningInferenceResults.NAME -> WarningInferenceResultsTests.createRandomResults();
case QuestionAnsweringInferenceResults.NAME -> QuestionAnsweringInferenceResultsTests.createRandomResults();
case TextExpansionResults.NAME -> TextExpansionResultsTests.createRandomResults();
default -> throw new AssertionError("unexpected result type [" + resultType + "]");
};
}

public void testToXContentString() {
// assert that the toXContent method does not error
for (var inferenceType : INFERENCE_RESULT_TYPES) {
var s = Strings.toString(randomInferenceResult(inferenceType));
assertNotNull(s);
assertThat(s, not(containsString("error")));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,26 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class TextExpansionResultsTests extends InferenceResultsTestCase<TextExpansionResults> {
public static TextExpansionResults createRandomResults() {
int numTokens = randomIntBetween(0, 20);
List<TextExpansionResults.WeightedToken> tokenList = new ArrayList<>();
for (int i = 0; i < numTokens; i++) {
tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false)));
}
return new TextExpansionResults(randomAlphaOfLength(4), tokenList, randomBoolean());
}

@Override
protected Writeable.Reader<TextExpansionResults> instanceReader() {
return TextExpansionResults::new;
}

@Override
protected TextExpansionResults createTestInstance() {
int numTokens = randomIntBetween(0, 20);
List<TextExpansionResults.WeightedToken> tokenList = new ArrayList<>();
for (int i = 0; i < numTokens; i++) {
tokenList.add(new TextExpansionResults.WeightedToken(i, (float) randomDoubleBetween(0.0, 5.0, false)));
}
return new TextExpansionResults(randomAlphaOfLength(4), tokenList, randomBoolean());
return createRandomResults();
}

@Override
Expand All @@ -38,18 +43,15 @@ protected TextExpansionResults mutateInstance(TextExpansionResults instance) {
@Override
@SuppressWarnings("unchecked")
void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String resultsField) {
var ingestedTokens = (List<Map<String, Object>>) document.getFieldValue(
var ingestedTokens = (Map<String, Object>) document.getFieldValue(
resultsField + '.' + createdInstance.getResultsField(),
List.class
Map.class
);
var originalTokens = createdInstance.getWeightedTokens();
assertEquals(originalTokens.size(), ingestedTokens.size());
for (int i = 0; i < createdInstance.getWeightedTokens().size(); i++) {
assertEquals(
originalTokens.get(i).weight(),
(float) ingestedTokens.get(i).get(Integer.toString(originalTokens.get(i).token())),
0.0001
);
}
var tokenMap = createdInstance.getWeightedTokens()
.stream()
.collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight));
assertEquals(tokenMap.size(), ingestedTokens.size());

assertEquals(tokenMap, ingestedTokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void setLogging() throws IOException {
Request loggingSettings = new Request("PUT", "_cluster/settings");
loggingSettings.setJsonEntity("""
{"persistent" : {
"logger.org.elasticsearch.xpack.ml.inference.assignment" : "TRACE",
"logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"
Expand Down

0 comments on commit 8240641

Please sign in to comment.