Skip to content

Commit

Permalink
[LTR] Enable by default on stateful only (#103333)
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Dec 14, 2023
1 parent d0ef7b6 commit d28481a
Show file tree
Hide file tree
Showing 23 changed files with 68 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
LEARNING_TO_RANK("es.learning_to_rank_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null);

public final String systemProperty;
Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugin/ml/qa/basic-multi-node/build.gradle
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.internal.info.BuildParams

apply plugin: 'elasticsearch.legacy-java-rest-test'
Expand All @@ -17,7 +16,6 @@ testClusters.configureEach {
setting 'xpack.license.self_generated.type', 'trial'
setting 'indices.lifecycle.history_index_enabled', 'false'
setting 'slm.history_index_enabled', 'false'
requiresFeature 'es.learning_to_rank_feature_flag_enabled', Version.fromString("8.12.0")
}

if (BuildParams.inFipsJvm){
Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import org.elasticsearch.gradle.Version
apply plugin: 'elasticsearch.legacy-yaml-rest-test'

dependencies {
Expand Down Expand Up @@ -258,5 +257,4 @@ testClusters.configureEach {
user username: "no_ml", password: "x-pack-test-password", role: "minimal"
setting 'xpack.license.self_generated.type', 'trial'
setting 'xpack.security.enabled', 'true'
requiresFeature 'es.learning_to_rank_feature_flag_enabled', Version.fromString("8.12.0")
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public boolean isNlpEnabled() {
return true;
}

@Override
public boolean isLearningToRankEnabled() {
return true;
}

@Override
public String[] getAnalyticsDestIndexAllowedSettings() {
return ANALYTICS_DEST_INDEX_ALLOWED_SETTINGS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerBuilder;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerFeature;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankService;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
Expand Down Expand Up @@ -886,7 +885,7 @@ private static void reportClashingNodeAttribute(String attrName) {

@Override
public List<RescorerSpec<?>> getRescorers() {
if (enabled && LearningToRankRescorerFeature.isEnabled()) {
if (enabled && machineLearningExtension.get().isLearningToRankEnabled()) {
return List.of(
new RescorerSpec<>(
LearningToRankRescorerBuilder.NAME,
Expand Down Expand Up @@ -1797,7 +1796,7 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
);
namedXContent.addAll(new CorrelationNamedContentProvider().getNamedXContentParsers());
// LTR Combine with Inference named content provider when feature flag is removed
if (LearningToRankRescorerFeature.isEnabled()) {
if (machineLearningExtension.get().isLearningToRankEnabled()) {
namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
}
return namedXContent;
Expand Down Expand Up @@ -1885,7 +1884,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.addAll(new CorrelationNamedContentProvider().getNamedWriteables());
namedWriteables.addAll(new ChangePointNamedContentProvider().getNamedWriteables());
// LTR Combine with Inference named content provider when feature flag is removed
if (LearningToRankRescorerFeature.isEnabled()) {
if (machineLearningExtension.get().isLearningToRankEnabled()) {
namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
}
return namedWriteables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ default void configure(Settings settings) {}

boolean isNlpEnabled();

default boolean isLearningToRankEnabled() {
return false;
}

String[] getAnalyticsDestIndexAllowedSettings();

AbstractNodeAvailabilityZoneMapper getNodeAvailabilityZoneMapper(Settings settings, ClusterSettings clusterSettings);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

public class LocalStateMachineLearningAdOnly extends LocalStateMachineLearning {
public LocalStateMachineLearningAdOnly(final Settings settings, final Path configPath) {
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, true, false, false)));
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, true, false, false, false)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

public class LocalStateMachineLearningDfaOnly extends LocalStateMachineLearning {
public LocalStateMachineLearningDfaOnly(final Settings settings, final Path configPath) {
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, false, true, false)));
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, false, true, false, false)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

public class LocalStateMachineLearningNlpOnly extends LocalStateMachineLearning {
public LocalStateMachineLearningNlpOnly(final Settings settings, final Path configPath) {
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, false, false, true)));
super(settings, configPath, new MlTestExtensionLoader(new MlTestExtension(true, true, false, false, true, false)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,14 @@ private MachineLearningUsageTransportAction newUsageAction(
licenseState,
jobManagerHolder,
new MachineLearningExtensionHolder(
new MachineLearningTests.MlTestExtension(true, true, isAnomalyDetectionEnabled, isDataFrameAnalyticsEnabled, isNlpEnabled)
new MachineLearningTests.MlTestExtension(
true,
true,
isAnomalyDetectionEnabled,
isDataFrameAnalyticsEnabled,
isNlpEnabled,
true
)
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ public void testNoAttributes_givenClash() throws IOException {

public void testAnomalyDetectionOnly() throws IOException {
Settings settings = Settings.builder().put("path.home", createTempDir()).build();
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings)) {
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, true, false, false));
machineLearning.loadExtensions(loader);
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, true, false, false, false));
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings, loader)) {
List<RestHandler> restHandlers = machineLearning.getRestHandlers(settings, null, null, null, null, null, null);
assertThat(restHandlers, hasItem(instanceOf(RestMlInfoAction.class)));
assertThat(restHandlers, hasItem(instanceOf(RestGetJobsAction.class)));
Expand All @@ -242,9 +241,8 @@ public void testAnomalyDetectionOnly() throws IOException {

public void testDataFrameAnalyticsOnly() throws IOException {
Settings settings = Settings.builder().put("path.home", createTempDir()).build();
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings)) {
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, false, true, false));
machineLearning.loadExtensions(loader);
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, false, true, false, false));
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings, loader)) {
List<RestHandler> restHandlers = machineLearning.getRestHandlers(settings, null, null, null, null, null, null);
assertThat(restHandlers, hasItem(instanceOf(RestMlInfoAction.class)));
assertThat(restHandlers, not(hasItem(instanceOf(RestGetJobsAction.class))));
Expand All @@ -263,9 +261,8 @@ public void testDataFrameAnalyticsOnly() throws IOException {

public void testNlpOnly() throws IOException {
Settings settings = Settings.builder().put("path.home", createTempDir()).build();
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings)) {
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, false, false, true));
machineLearning.loadExtensions(loader);
MlTestExtensionLoader loader = new MlTestExtensionLoader(new MlTestExtension(false, false, false, false, true, false));
try (MachineLearning machineLearning = createTrialLicensedMachineLearning(settings, loader)) {
List<RestHandler> restHandlers = machineLearning.getRestHandlers(settings, null, null, null, null, null, null);
assertThat(restHandlers, hasItem(instanceOf(RestMlInfoAction.class)));
assertThat(restHandlers, not(hasItem(instanceOf(RestGetJobsAction.class))));
Expand All @@ -291,19 +288,22 @@ public static class MlTestExtension implements MachineLearningExtension {
private final boolean isAnomalyDetectionEnabled;
private final boolean isDataFrameAnalyticsEnabled;
private final boolean isNlpEnabled;
private final boolean isLearningToRankEnabled;

MlTestExtension(
boolean useIlm,
boolean includeNodeInfo,
boolean isAnomalyDetectionEnabled,
boolean isDataFrameAnalyticsEnabled,
boolean isNlpEnabled
boolean isNlpEnabled,
boolean isLearningToRankEnabled
) {
this.useIlm = useIlm;
this.includeNodeInfo = includeNodeInfo;
this.isAnomalyDetectionEnabled = isAnomalyDetectionEnabled;
this.isDataFrameAnalyticsEnabled = isDataFrameAnalyticsEnabled;
this.isNlpEnabled = isNlpEnabled;
this.isLearningToRankEnabled = isLearningToRankEnabled;
}

@Override
Expand Down Expand Up @@ -331,6 +331,11 @@ public boolean isNlpEnabled() {
return isNlpEnabled;
}

@Override
public boolean isLearningToRankEnabled() {
return isLearningToRankEnabled;
}

@Override
public String[] getAnalyticsDestIndexAllowedSettings() {
return ANALYTICS_DEST_INDEX_ALLOWED_SETTINGS;
Expand Down Expand Up @@ -377,6 +382,12 @@ protected XPackLicenseState getLicenseState() {
}

public static MachineLearning createTrialLicensedMachineLearning(Settings settings) {
return new TrialLicensedMachineLearning(settings);
return createTrialLicensedMachineLearning(settings, null);
}

public static MachineLearning createTrialLicensedMachineLearning(Settings settings, MlTestExtensionLoader loader) {
MachineLearning mlPlugin = new TrialLicensedMachineLearning(settings);
mlPlugin.loadExtensions(loader);
return mlPlugin;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.search.aggregations.metrics.Min;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -53,7 +54,7 @@ protected AnalysisModule createAnalysisModule() throws Exception {

@Override
protected List<SearchPlugin> getSearchPlugins() {
return List.of(new MachineLearning(Settings.EMPTY));
return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
}

private static final String TEXT_FIELD_NAME = "text";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.elasticsearch.test.InternalMultiBucketAggregationTestCase;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -49,7 +49,7 @@ public void destroyHash() {

@Override
protected SearchPlugin registerPlugin() {
return new MachineLearning(Settings.EMPTY);
return MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -40,7 +40,7 @@ public class ChangePointAggregatorTests extends AggregatorTestCase {

@Override
protected List<SearchPlugin> getSearchPlugins() {
return List.of(new MachineLearning(Settings.EMPTY));
return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
}

private static final DateHistogramInterval INTERVAL = DateHistogramInterval.minutes(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.EclatMapReducer.EclatResult;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetCollector.FrequentItemSet;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.InternalItemSetMapReduceAggregation;
Expand Down Expand Up @@ -66,7 +66,7 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {

@Override
protected List<SearchPlugin> getSearchPlugins() {
return List.of(new MachineLearning(Settings.EMPTY));
return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.InternalItemSetMapReduceAggregationTests.WordCountMapReducer.WordCounts;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource.Field;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource.ValueFormatter;
Expand Down Expand Up @@ -247,7 +247,7 @@ protected void assertFromXContent(

@Override
protected SearchPlugin registerPlugin() {
return new MachineLearning(Settings.EMPTY);
return MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;

import java.util.List;
import java.util.function.Function;
Expand Down Expand Up @@ -56,16 +57,14 @@ public void testAssertions() {

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(
new SearchModule(Settings.EMPTY, List.of(new MachineLearning(Settings.EMPTY))).getNamedXContents()
);
MachineLearning mlPlugin = MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY);
return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of(mlPlugin)).getNamedXContents());
}

@Override
protected NamedWriteableRegistry writableRegistry() {
return new NamedWriteableRegistry(
new SearchModule(Settings.EMPTY, List.of(new MachineLearning(Settings.EMPTY))).getNamedWriteables()
);
MachineLearning mlPlugin = MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY);
return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, List.of(mlPlugin)).getNamedWriteables());
}

public void testPValueScore_WhenAllDocsContainTerm() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningTests;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
Expand All @@ -43,7 +42,7 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg

@Override
protected List<SearchPlugin> plugins() {
return Collections.singletonList(new MachineLearning(Settings.EMPTY));
return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
}

@Override
Expand Down

0 comments on commit d28481a

Please sign in to comment.