Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/117287.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117287
summary: Fixing bug setting index when parsing Google Vertex AI results
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
public class GoogleVertexAiRerankResponseEntity {

private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Google Vertex AI rerank response";
private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Google Vertex AI rerank "
+ "response but received [%s]";

/**
* Parses the Google Vertex AI rerank response.
Expand Down Expand Up @@ -109,14 +111,27 @@ private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser)
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
}

return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
if (parsedRankedDoc.id == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName()));
}

try {
return new RankedDocsResults.RankedDoc(
Integer.parseInt(parsedRankedDoc.id),
parsedRankedDoc.score,
parsedRankedDoc.content
);
} catch (NumberFormatException e) {
throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
}
});
}

private record RankedDoc(@Nullable Float score, @Nullable String content) {
private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) {

private static final ParseField CONTENT = new ParseField("content");
private static final ParseField SCORE = new ParseField("score");
private static final ParseField ID = new ParseField("id");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
"google_vertex_ai_rerank_response",
true,
Expand All @@ -126,6 +141,7 @@ private record RankedDoc(@Nullable Float score, @Nullable String content) {
static {
PARSER.declareString(Builder::setContent, CONTENT);
PARSER.declareFloat(Builder::setScore, SCORE);
PARSER.declareString(Builder::setId, ID);
}

public static RankedDoc parse(XContentParser parser) {
Expand All @@ -137,6 +153,7 @@ private static final class Builder {

private String content;
private Float score;
private String id;

private Builder() {}

Expand All @@ -150,8 +167,13 @@ public Builder setContent(String content) {
return this;
}

public Builder setId(String id) {
this.id = id;
return this;
}

public RankedDoc build() {
return new RankedDoc(score, content);
return new RankedDoc(score, content, id);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);

assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"))));
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
}

public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
Expand Down Expand Up @@ -68,7 +68,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException

assertThat(
parsedResults.getRankedDocs(),
is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
);
}

Expand Down Expand Up @@ -161,4 +161,37 @@ public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {

assertThat(thrownException.getMessage(), is("Failed to find required field [score] in Google Vertex AI rerank response"));
}

public void testFromResponse_FailsWhenIDFieldIsNotInteger() {
String responseJson = """
{
"records": [
{
"id": "abcd",
"title": "title 2",
"content": "content 2",
"score": 0.97
},
{
"id": "1",
"title": "title 1",
"content": "content 1",
"score": 0.96
}
]
}
""";

var thrownException = expectThrows(
IllegalStateException.class,
() -> GoogleVertexAiRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);

assertThat(
thrownException.getMessage(),
is("Expected numeric value for record ID field in Google Vertex AI rerank response but received [abcd]")
);
}
}