Skip to content

Commit

Permalink
[ML] Change result format to suit rank feature mapping (#93460)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Feb 7, 2023
1 parent e6bbfa2 commit d52b6f9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ public class SlimResults extends NlpInferenceResults {

public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject {

public static final String TOKEN = "token";
public static final String WEIGHT = "weight";

public WeightedToken(StreamInput in) throws IOException {
this(in.readVInt(), in.readFloat());
}
Expand All @@ -42,14 +39,13 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TOKEN, token);
builder.field(WEIGHT, weight);
builder.field(Integer.toString(token), weight);
builder.endObject();
return builder;
}

public Map<String, Object> asMap() {
return Map.of(TOKEN, token, WEIGHT, weight);
return Map.of(Integer.toString(token), weight);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected SlimResults createTestInstance() {

@Override
protected SlimResults mutateInstance(SlimResults instance) {
return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
return new SlimResults(instance.getResultsField() + "-FOO", instance.getWeightedTokens(), instance.isTruncated() == false);
}

@Override
Expand All @@ -45,8 +45,11 @@ void assertFieldValues(SlimResults createdInstance, IngestDocument document, Str
var originalTokens = createdInstance.getWeightedTokens();
assertEquals(originalTokens.size(), ingestedTokens.size());
for (int i = 0; i < createdInstance.getWeightedTokens().size(); i++) {
assertEquals(originalTokens.get(i).token(), (int) ingestedTokens.get(i).get("token"));
assertEquals(originalTokens.get(i).weight(), (float) ingestedTokens.get(i).get("weight"), 0.0001);
assertEquals(
originalTokens.get(i).weight(),
(float) ingestedTokens.get(i).get(Integer.toString(originalTokens.get(i).token())),
0.0001
);
}
}
}

0 comments on commit d52b6f9

Please sign in to comment.