[ML] Remove uses of ML HLRC classes (#83885)
Removes all references to the ML HLRC classes in preparation
for the removal of those classes. Mostly this means converting tests
that use the HLRC to the low-level client.
davidkyle committed Feb 14, 2022
1 parent c18022b commit aa8da2f
Showing 8 changed files with 336 additions and 316 deletions.
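
The changes follow a single pattern: typed calls through the RestHighLevelClient (for example hlrc.machineLearning().putJob(new PutJobRequest(...), options)) are replaced with raw Request objects sent through the low-level RestClient, and responses are parsed as JSON maps where the tests need to inspect them. A minimal sketch of the low-level side, assuming a test class that extends ESRestTestCase; the job id and config are illustrative, not copied from the diff below:

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;

// Create an anomaly detection job via the ML REST endpoint instead of the HLRC.
private Response putAnomalyDetectionJob(String jobId) throws Exception {
    Request request = new Request("PUT", "/_ml/anomaly_detectors/" + jobId);
    request.setJsonEntity("""
        {
          "analysis_config": {
            "bucket_span": "3600s",
            "detectors": [{"function": "count"}]
          },
          "data_description": {"time_field": "time"}
        }""");
    // client() is the low-level RestClient provided by ESRestTestCase.
    return client().performRequest(request);
}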
@@ -6,11 +6,11 @@
*/
package org.elasticsearch.xpack.core.ml.job.results;

import org.elasticsearch.client.ml.job.config.DetectorFunction;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.job.config.DetectorFunction;

import java.io.IOException;
import java.util.ArrayList;

@@ -6,7 +6,6 @@
*/
package org.elasticsearch.xpack.core.ml.job.results;

import org.elasticsearch.client.ml.job.config.DetectorFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
@@ -16,6 +15,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.job.config.DetectorFunction;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;

import java.io.IOException;

@@ -6,8 +6,6 @@
*/
package org.elasticsearch.xpack.core.ml.utils;

import org.elasticsearch.client.ml.inference.NamedXContentObject;
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;

@@ -7,28 +7,17 @@

package org.elasticsearch.xpack.deprecation;

import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.WarningsHandler;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.DataDescription;
import org.elasticsearch.client.ml.job.config.Detector;
import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentType;
import org.junit.After;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -44,12 +33,6 @@ public class MlDeprecationIT extends ESRestTestCase {
.setWarningsHandler(WarningsHandler.PERMISSIVE)
.build();

private static class HLRC extends RestHighLevelClient {
HLRC(RestClient restClient) {
super(restClient, RestClient::close, new ArrayList<>());
}
}

@After
public void resetFeatures() throws IOException {
Response response = adminClient().performRequest(new Request("POST", "/_features/_reset"));
@@ -69,32 +52,21 @@ protected boolean enableWarningsCheck() {

@SuppressWarnings("unchecked")
public void testMlDeprecationChecks() throws Exception {
HLRC hlrc = new HLRC(client());
String jobId = "deprecation_check_job";
hlrc.machineLearning()
.putJob(
new PutJobRequest(
Job.builder(jobId)
.setAnalysisConfig(
AnalysisConfig.builder(Collections.singletonList(Detector.builder().setFunction("count").build()))
)
.setDataDescription(new DataDescription.Builder().setTimeField("time"))
.build()
),
REQUEST_OPTIONS
);

IndexRequest indexRequest = new IndexRequest(".ml-anomalies-.write-" + jobId).id(jobId + "_model_snapshot_1")
.source("{\"job_id\":\"deprecation_check_job\",\"snapshot_id\":\"1\", \"snapshot_doc_count\":1}", XContentType.JSON);
hlrc.index(indexRequest, REQUEST_OPTIONS);

indexRequest = new IndexRequest(".ml-anomalies-.write-" + jobId).id(jobId + "_model_snapshot_2")
.source(
"{\"job_id\":\"deprecation_check_job\",\"snapshot_id\":\"2\",\"snapshot_doc_count\":1,\"min_version\":\"8.0.0\"}",
XContentType.JSON
);
hlrc.index(indexRequest, REQUEST_OPTIONS);
hlrc.indices().refresh(new RefreshRequest(".ml-anomalies-*"), REQUEST_OPTIONS);
buildAndPutJob(jobId);

indexDoc(
".ml-anomalies-.write-" + jobId,
jobId + "_model_snapshot_1",
"{\"job_id\":\"deprecation_check_job\",\"snapshot_id\":\"1\", \"snapshot_doc_count\":1}"
);

indexDoc(
".ml-anomalies-.write-" + jobId,
jobId + "_model_snapshot_2",
"{\"job_id\":\"deprecation_check_job\",\"snapshot_id\":\"2\",\"snapshot_doc_count\":1,\"min_version\":\"8.0.0\"}"
);
client().performRequest(new Request("POST", "/.ml-anomalies-*/_refresh"));

// specify an index so that deprecation checks don't run against any accidentally existing indices
Request getDeprecations = new Request("GET", "/does-not-exist-*/_migration/deprecations");
@@ -108,4 +80,30 @@ public void testMlDeprecationChecks() throws Exception {
assertThat(mlSettingsDeprecations.get(0).get("_meta"), equalTo(Map.of("job_id", jobId, "snapshot_id", "1")));
}

private Response buildAndPutJob(String jobId) throws Exception {
String jobConfig = """
{
"analysis_config" : {
"bucket_span": "3600s",
"detectors" :[{"function":"count"}]
},
"data_description" : {
"time_field":"time",
"time_format":"yyyy-MM-dd HH:mm:ssX"
}
}""";

Request request = new Request("PUT", "/_ml/anomaly_detectors/" + jobId);
request.setOptions(REQUEST_OPTIONS);
request.setJsonEntity(jobConfig);
return client().performRequest(request);
}

private Response indexDoc(String index, String docId, String source) throws IOException {
Request request = new Request("PUT", "/" + index + "/_doc/" + docId);
request.setOptions(REQUEST_OPTIONS);
request.setJsonEntity(source);
return client().performRequest(request);
}

}

@@ -10,12 +10,11 @@
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.test.ExternalTestCluster;
@@ -24,9 +23,7 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
@@ -36,6 +33,7 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

@@ -124,8 +122,8 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
assertBusy(() -> {
try {
assertStatsWithCacheMisses(classificationModelId, 10L);
assertStatsWithCacheMisses(regressionModelId, 10L);
assertStatsWithCacheMisses(classificationModelId, 10);
assertStatsWithCacheMisses(regressionModelId, 10);
} catch (ResponseException ex) {
// this could just mean shard failures.
fail(ex.getMessage());
@@ -176,15 +174,16 @@ public void testPipelineIngest() throws Exception {

assertBusy(() -> {
try {
assertStatsWithCacheMisses(classificationModelId, 10L);
assertStatsWithCacheMisses(regressionModelId, 15L);
assertStatsWithCacheMisses(classificationModelId, 10);
assertStatsWithCacheMisses(regressionModelId, 15);
} catch (ResponseException ex) {
// this could just mean shard failures.
fail(ex.getMessage());
}
}, 30, TimeUnit.SECONDS);
}

@SuppressWarnings("unchecked")
public void testPipelineIngestWithModelAliases() throws Exception {
String regressionModelId = "test_regression_1";
putModel(regressionModelId, REGRESSION_CONFIG);
@@ -255,34 +254,33 @@ public void testPipelineIngestWithModelAliases() throws Exception {
assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0")));

assertBusy(() -> {
try (
XContentParser parser = createParser(
JsonXContent.jsonXContent,
client().performRequest(new Request("GET", "_ml/trained_models/" + modelAlias + "/_stats")).getEntity().getContent()
)
) {
GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser);
assertThat(response.toString(), response.getTrainedModelStats(), hasSize(1));
TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0);
assertThat(trainedModelStats.getModelId(), equalTo(regressionModelId2));
assertThat(trainedModelStats.getInferenceStats(), is(notNullValue()));
try {
Response response = client().performRequest(new Request("GET", "_ml/trained_models/" + modelAlias + "/_stats"));
var responseMap = entityAsMap(response);
assertThat((List<?>) responseMap.get("trained_model_stats"), hasSize(1));
var stats = ((List<Map<String, Object>>) responseMap.get("trained_model_stats")).get(0);
assertThat(stats.get("model_id"), equalTo(regressionModelId2));
assertThat(stats.get("inference_stats"), is(notNullValue()));
} catch (ResponseException ex) {
// this could just mean shard failures.
fail(ex.getMessage());
}
});
}

public void assertStatsWithCacheMisses(String modelId, long inferenceCount) throws IOException {
@SuppressWarnings("unchecked")
public void assertStatsWithCacheMisses(String modelId, int inferenceCount) throws IOException {
Response statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/" + modelId + "/_stats"));
try (XContentParser parser = createParser(JsonXContent.jsonXContent, statsResponse.getEntity().getContent())) {
GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser);
assertThat(response.getTrainedModelStats(), hasSize(1));
TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0);
assertThat(trainedModelStats.getInferenceStats(), is(notNullValue()));
assertThat(trainedModelStats.getInferenceStats().getInferenceCount(), equalTo(inferenceCount));
assertThat(trainedModelStats.getInferenceStats().getCacheMissCount(), greaterThan(0L));
}
var responseMap = entityAsMap(statsResponse);
assertThat((List<?>) responseMap.get("trained_model_stats"), hasSize(1));
var stats = ((List<Map<String, Object>>) responseMap.get("trained_model_stats")).get(0);
assertThat(stats.get("inference_stats"), is(notNullValue()));
assertThat(
stats.toString(),
(Integer) XContentMapValues.extractValue("inference_stats.inference_count", stats),
equalTo(inferenceCount)
);
assertThat(stats.toString(), (Integer) XContentMapValues.extractValue("inference_stats.cache_miss_count", stats), greaterThan(0));
}

public void testSimulate() throws IOException {
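
The updated stats assertions above use the entityAsMap helper inherited from ESRestTestCase to turn the low-level Response into a Map, and XContentMapValues.extractValue to read nested fields, instead of the removed GetTrainedModelsStatsResponse HLRC class. A small self-contained illustration of the extractValue idiom; the map contents here are made up:

import java.util.Map;

import org.elasticsearch.common.xcontent.support.XContentMapValues;

public class ExtractValueExample {
    public static void main(String[] args) {
        // Shape of a stats response body after parsing, e.g. via entityAsMap(response).
        Map<String, Object> stats = Map.of(
            "model_id", "my-model",
            "inference_stats", Map.of("inference_count", 10, "cache_miss_count", 3)
        );
        // Dotted paths walk nested maps in a single call.
        Object count = XContentMapValues.extractValue("inference_stats.inference_count", stats);
        System.out.println(count); // prints 10
    }
}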
