Skip to content

Commit

Permalink
[LTR] Rescore window size improvements. (#104318)
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Jan 17, 2024
1 parent 96d83cd commit 0aa194a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 99 deletions.
Expand Up @@ -73,6 +73,8 @@ public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consum
RescorerBuilder<?> rescorer = null;
Integer windowSize = null;
XContentParser.Token token;
String rescorerType = null;

while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
Expand All @@ -83,18 +85,25 @@ public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consum
throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]");
}
} else if (token == XContentParser.Token.START_OBJECT) {
rescorer = parser.namedObject(RescorerBuilder.class, fieldName, null);
rescorerNameConsumer.accept(fieldName);
if (fieldName != null) {
rescorer = parser.namedObject(RescorerBuilder.class, fieldName, null);
rescorerNameConsumer.accept(fieldName);
rescorerType = fieldName;
}
} else {
throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]");
}
}
if (rescorer == null) {
throw new ParsingException(parser.getTokenLocation(), "missing rescore type");
}

if (windowSize != null) {
rescorer.windowSize(windowSize.intValue());
} else if (rescorer.isWindowSizeRequired()) {
throw new ParsingException(parser.getTokenLocation(), "window_size is required for rescorer of type [" + rescorerType + "]");
}

return rescorer;
}

Expand All @@ -111,11 +120,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException;

/**
* Indicate if the window_size is a required parameter for the rescorer.
*/
protected boolean isWindowSizeRequired() {
return false;
}

/**
* Build the {@linkplain RescoreContext} that will be used to actually
* execute the rescore against a particular shard.
*/
public final RescoreContext buildContext(SearchExecutionContext context) throws IOException {
if (isWindowSizeRequired()) {
assert windowSize != null;
}
int finalWindowSize = windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize;
RescoreContext rescoreContext = innerBuildContext(finalWindowSize, context);
return rescoreContext;
Expand Down
Expand Up @@ -55,6 +55,15 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r
if (ltrRescoreContext.regressionModelDefinition == null) {
throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?");
}

if (rescoreContext.getWindowSize() < topDocs.scoreDocs.length) {
throw new IllegalArgumentException(
"Rescore window is too small and should be at least the value of from + size but was ["
+ rescoreContext.getWindowSize()
+ "]"
);
}

LocalModel definition = ltrRescoreContext.regressionModelDefinition;

// First take top slice of incoming docs, to be rescored:
Expand Down
Expand Up @@ -32,10 +32,10 @@

public class LearningToRankRescorerBuilder extends RescorerBuilder<LearningToRankRescorerBuilder> {

public static final String NAME = "learning_to_rank";
private static final ParseField MODEL_FIELD = new ParseField("model_id");
private static final ParseField PARAMS_FIELD = new ParseField("params");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, false, Builder::new);
public static final ParseField NAME = new ParseField("learning_to_rank");
public static final ParseField MODEL_FIELD = new ParseField("model_id");
public static final ParseField PARAMS_FIELD = new ParseField("params");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), false, Builder::new);

static {
PARSER.declareString(Builder::setModelId, MODEL_FIELD);
Expand Down Expand Up @@ -251,7 +251,7 @@ protected LearningToRankRescorerContext innerBuildContext(int windowSize, Search

@Override
public String getWriteableName() {
return NAME;
return NAME.getPreferredName();
}

@Override
Expand All @@ -260,6 +260,11 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
}

@Override
protected boolean isWindowSizeRequired() {
return true;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
assert localModel == null || rescoreOccurred : "Unnecessarily populated local model object";
Expand All @@ -270,7 +275,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.startObject(NAME.getPreferredName());
builder.field(MODEL_FIELD.getPreferredName(), modelId);
if (this.params != null) {
builder.field(PARAMS_FIELD.getPreferredName(), this.params);
Expand Down
Expand Up @@ -9,14 +9,19 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
Expand All @@ -25,48 +30,36 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.search.rank.RankBuilder.WINDOW_SIZE_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfigTests.randomLearningToRankConfig;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class LearningToRankRescorerBuilderSerializationTests extends AbstractBWCSerializationTestCase<LearningToRankRescorerBuilder> {

private static LearningToRankService learningToRankService = mock(LearningToRankService.class);

@Override
protected LearningToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException {
String fieldName = null;
LearningToRankRescorerBuilder rescorer = null;
Integer windowSize = null;
XContentParser.Token token = parser.nextToken();
assert token == XContentParser.Token.START_OBJECT;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
} else if (token.isValue()) {
if (WINDOW_SIZE_FIELD.match(fieldName, parser.getDeprecationHandler())) {
windowSize = parser.intValue();
} else {
throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]");
public void testRequiredWindowSize() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) {
LearningToRankRescorerBuilder testInstance = createTestInstance();
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
builder.startObject();
testInstance.doXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();

try (XContentParser parser = JsonXContent.jsonXContent.createParser(parserConfig(), Strings.toString(builder))) {
ParsingException e = expectThrows(ParsingException.class, () -> RescorerBuilder.parseFromXContent(parser, (r) -> {}));
assertThat(e.getMessage(), equalTo("window_size is required for rescorer of type [learning_to_rank]"));
}
} else if (token == XContentParser.Token.START_OBJECT) {
rescorer = LearningToRankRescorerBuilder.fromXContent(parser, learningToRankService);
} else {
throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]");
}
}
if (rescorer == null) {
throw new ParsingException(parser.getTokenLocation(), "missing rescore type");
}
if (windowSize != null) {
rescorer.windowSize(windowSize);
}
return rescorer;
}

@Override
protected LearningToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException {
return (LearningToRankRescorerBuilder) RescorerBuilder.parseFromXContent(parser, (r) -> {});
}

@Override
Expand All @@ -85,76 +78,49 @@ protected LearningToRankRescorerBuilder createTestInstance() {
learningToRankService
);

if (randomBoolean()) {
builder.windowSize(randomIntBetween(1, 10000));
}
builder.windowSize(randomIntBetween(1, 10000));

return builder;
}

@Override
protected LearningToRankRescorerBuilder createXContextTestInstance(XContentType xContentType) {
return new LearningToRankRescorerBuilder(randomAlphaOfLength(10), randomBoolean() ? randomParams() : null, learningToRankService);
return new LearningToRankRescorerBuilder(randomAlphaOfLength(10), randomBoolean() ? randomParams() : null, learningToRankService)
.windowSize(randomIntBetween(1, 10000));
}

@Override
protected LearningToRankRescorerBuilder mutateInstance(LearningToRankRescorerBuilder instance) throws IOException {

int i = randomInt(4);
return switch (i) {
case 0 -> {
LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder(
randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)),
instance.params(),
learningToRankService
);
if (instance.windowSize() != null) {
builder.windowSize(instance.windowSize());
}
yield builder;
}
case 0 -> new LearningToRankRescorerBuilder(
randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)),
instance.params(),
learningToRankService
).windowSize(instance.windowSize());
case 1 -> new LearningToRankRescorerBuilder(instance.modelId(), instance.params(), learningToRankService).windowSize(
randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000))
);
case 2 -> {
LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder(
instance.modelId(),
randomValueOtherThan(instance.params(), () -> (randomBoolean() ? randomParams() : null)),
learningToRankService
);
if (instance.windowSize() != null) {
builder.windowSize(instance.windowSize() + 1);
}
yield builder;
}
case 2 -> new LearningToRankRescorerBuilder(
instance.modelId(),
randomValueOtherThan(instance.params(), () -> (randomBoolean() ? randomParams() : null)),
learningToRankService
).windowSize(instance.windowSize());
case 3 -> {
LearningToRankConfig learningToRankConfig = randomValueOtherThan(
instance.learningToRankConfig(),
() -> randomLearningToRankConfig()
);
LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder(
instance.modelId(),
learningToRankConfig,
null,
learningToRankService
yield new LearningToRankRescorerBuilder(instance.modelId(), learningToRankConfig, null, learningToRankService).windowSize(
instance.windowSize()
);
if (instance.windowSize() != null) {
builder.windowSize(instance.windowSize());
}
yield builder;
}
case 4 -> {
LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder(
mock(LocalModel.class),
instance.learningToRankConfig(),
instance.params(),
learningToRankService
);
if (instance.windowSize() != null) {
builder.windowSize(instance.windowSize());
}
yield builder;
}
case 4 -> new LearningToRankRescorerBuilder(
mock(LocalModel.class),
instance.learningToRankConfig(),
instance.params(),
learningToRankService
).windowSize(instance.windowSize());
default -> throw new AssertionError("Unexpected random test case");
};
}
Expand All @@ -169,31 +135,38 @@ protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents());
namedXContent.add(
new NamedXContentRegistry.Entry(
RescorerBuilder.class,
LearningToRankRescorerBuilder.NAME,
(p, c) -> LearningToRankRescorerBuilder.fromXContent(p, learningToRankService)
)
);
return new NamedXContentRegistry(namedXContent);
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return writableRegistry();
}

@Override
protected NamedWriteableRegistry writableRegistry() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
namedWriteables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables());
namedWriteables.add(
new NamedWriteableRegistry.Entry(
RescorerBuilder.class,
LearningToRankRescorerBuilder.NAME.getPreferredName(),
in -> new LearningToRankRescorerBuilder(in, learningToRankService)
)
);
return new NamedWriteableRegistry(namedWriteables);
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return writableRegistry();
}

private static Map<String, Object> randomParams() {
return randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier()));
}

private static LocalModel localModelMock() {
LocalModel model = mock(LocalModel.class);
String modelId = randomIdentifier();
when(model.getModelId()).thenReturn(modelId);
return model;
}
}

0 comments on commit 0aa194a

Please sign in to comment.