Skip to content

Commit

Permalink
[ML] Provide a way to revert an AD job to an empty snapshot (#65431)
Browse files Browse the repository at this point in the history
This commit adds a way to revert an anomaly detection job to
the empty snapshot. Combining this with `delete_intervening_results`
to `true`, the user can now reset the job and start over.

The API call looks like this:

POST _ml/anomaly_detectors/<job_id>/model_snapshots/empty/_revert
  • Loading branch information
dimitris-athanasiou committed Nov 24, 2020
1 parent 699af9d commit 350c2b9
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -469,7 +470,7 @@ public Job mergeWithJob(Job source, ByteSizeValue maxModelMemoryLimit) {
builder.setCustomSettings(customSettings);
}
if (modelSnapshotId != null) {
builder.setModelSnapshotId(modelSnapshotId);
builder.setModelSnapshotId(ModelSnapshot.isTheEmptySnapshot(modelSnapshotId) ? null : modelSnapshotId);
}
if (modelSnapshotMinVersion != null) {
builder.setModelSnapshotMinVersion(modelSnapshotMinVersion);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ private static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFie
return parser;
}

private static String EMPTY_SNAPSHOT_ID = "empty";

private final String jobId;

Expand Down Expand Up @@ -285,6 +286,14 @@ public List<String> stateDocumentIds() {
return stateDocumentIds;
}

public boolean isTheEmptySnapshot() {
return isTheEmptySnapshot(snapshotId);
}

public static boolean isTheEmptySnapshot(String snapshotId) {
return EMPTY_SNAPSHOT_ID.equals(snapshotId);
}

public static String documentIdPrefix(String jobId) {
return jobId + "_" + TYPE + "_";
}
Expand Down Expand Up @@ -435,4 +444,9 @@ public ModelSnapshot build() {
latestRecordTimeStamp, latestResultTimeStamp, quantiles, retain);
}
}

public static ModelSnapshot emptySnapshot(String jobId) {
return new ModelSnapshot(jobId, Version.CURRENT, new Date(), "empty snapshot", EMPTY_SNAPSHOT_ID, 0,
new ModelSizeStats.Builder(jobId).build(), null, null, null, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -26,7 +27,9 @@
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;

public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
Expand Down Expand Up @@ -369,4 +372,23 @@ public void testUpdate_withAnalysisLimitsPreviouslyUndefined() {

updateAboveMaxLimit.mergeWithJob(jobBuilder.build(), new ByteSizeValue(10000L, ByteSizeUnit.MB));
}

public void testUpdate_givenEmptySnapshot() {
Job.Builder jobBuilder = new Job.Builder("my_job");
Detector.Builder d1 = new Detector.Builder("count", null);
AnalysisConfig.Builder ac = new AnalysisConfig.Builder(Collections.singletonList(d1.build()));
jobBuilder.setAnalysisConfig(ac);
jobBuilder.setDataDescription(new DataDescription.Builder());
jobBuilder.setCreateTime(new Date());
jobBuilder.setModelSnapshotId("some_snapshot_id");
Job job = jobBuilder.build();
assertThat(job.getModelSnapshotId(), equalTo("some_snapshot_id"));

JobUpdate update = new JobUpdate.Builder(job.getId())
.setModelSnapshotId(ModelSnapshot.emptySnapshot(job.getId()).getSnapshotId())
.build();

Job updatedJob = update.mergeWithJob(job, ByteSizeValue.ofMb(100));
assertThat(updatedJob.getModelSnapshotId(), is(nullValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class ModelSnapshotTests extends AbstractSerializingTestCase<ModelSnapshot> {
private static final Date DEFAULT_TIMESTAMP = new Date();
Expand Down Expand Up @@ -155,7 +157,7 @@ public static ModelSnapshot createRandomized() {
modelSnapshot.setMinVersion(Version.CURRENT);
modelSnapshot.setTimestamp(new Date(TimeValue.parseTimeValue(randomTimeValue(), "test").millis()));
modelSnapshot.setDescription(randomAlphaOfLengthBetween(1, 20));
modelSnapshot.setSnapshotId(randomAlphaOfLengthBetween(1, 20));
modelSnapshot.setSnapshotId(randomAlphaOfLength(10));
modelSnapshot.setSnapshotDocCount(randomInt());
modelSnapshot.setModelSizeStats(ModelSizeStatsTests.createRandomized());
modelSnapshot.setLatestResultTimeStamp(
Expand Down Expand Up @@ -214,4 +216,18 @@ public void testLenientParser() throws IOException {
ModelSnapshot.LENIENT_PARSER.apply(parser, null);
}
}

public void testEmptySnapshot() {
ModelSnapshot modelSnapshot = ModelSnapshot.emptySnapshot("my_job");
assertThat(modelSnapshot.getSnapshotId(), equalTo("empty"));
assertThat(modelSnapshot.isTheEmptySnapshot(), is(true));
assertThat(modelSnapshot.getMinVersion(), equalTo(Version.CURRENT));
assertThat(modelSnapshot.getLatestRecordTimeStamp(), is(nullValue()));
assertThat(modelSnapshot.getLatestResultTimeStamp(), is(nullValue()));
}

public void testIsEmpty_GivenNonEmptySnapshot() {
ModelSnapshot modelSnapshot = createRandomized();
assertThat(modelSnapshot.isTheEmptySnapshot(), is(false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.annotations.Annotation;
import org.elasticsearch.xpack.core.ml.annotations.Annotation.Event;
import org.elasticsearch.xpack.core.ml.annotations.AnnotationIndex;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.Quantiles;
import org.elasticsearch.xpack.core.ml.job.results.AnomalyRecord;
import org.elasticsearch.xpack.core.ml.job.results.Bucket;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.junit.After;
Expand All @@ -49,6 +52,8 @@
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

/**
* This test pushes data through a job in 2 runs creating
Expand All @@ -58,19 +63,60 @@
public class RevertModelSnapshotIT extends MlNativeAutodetectIntegTestCase {

@After
public void tearDownData() throws Exception {
public void tearDownData() {
cleanUp();
}

public void testRevertModelSnapshot() throws Exception {
test("revert-model-snapshot-it-job", false);
testRunJobInTwoPartsAndRevertSnapshotAndRunToCompletion("revert-model-snapshot-it-job", false);
}

public void testRevertModelSnapshot_DeleteInterveningResults() throws Exception {
test("revert-model-snapshot-it-job-delete-intervening-results", true);
testRunJobInTwoPartsAndRevertSnapshotAndRunToCompletion("revert-model-snapshot-it-job-delete-intervening-results", true);
}

private void test(String jobId, boolean deleteInterveningResults) throws Exception {
public void testRevertToEmptySnapshot() throws Exception {
String jobId = "revert-to-empty-snapshot-test";

TimeValue bucketSpan = TimeValue.timeValueHours(1);
long startTime = 1491004800000L;

String data = generateData(startTime, bucketSpan, 20, Arrays.asList("foo"),
(bucketIndex, series) -> bucketIndex == 19 ? 100.0 : 10.0).stream().collect(Collectors.joining());

Job.Builder job = buildAndRegisterJob(jobId, bucketSpan);
openJob(job.getId());
postData(job.getId(), data);
flushJob(job.getId(), true);
closeJob(job.getId());

assertThat(getJob(jobId).get(0).getModelSnapshotId(), is(notNullValue()));
List<Bucket> expectedBuckets = getBuckets(jobId);
assertThat(expectedBuckets.size(), equalTo(20));
List<AnomalyRecord> expectedRecords = getRecords(jobId);
assertThat(expectedBuckets.isEmpty(), is(false));
assertThat(expectedRecords.isEmpty(), is(false));

RevertModelSnapshotAction.Response revertResponse = revertModelSnapshot(jobId, "empty", true);
assertThat(revertResponse.getModel().getSnapshotId(), equalTo("empty"));

assertThat(getJob(jobId).get(0).getModelSnapshotId(), is(nullValue()));
assertThat(getBuckets(jobId).isEmpty(), is(true));
assertThat(getRecords(jobId).isEmpty(), is(true));
assertThat(getJobStats(jobId).get(0).getDataCounts().getLatestRecordTimeStamp(), is(nullValue()));

// Now run again and see we get same results
openJob(job.getId());
DataCounts dataCounts = postData(job.getId(), data);
assertThat(dataCounts.getOutOfOrderTimeStampCount(), equalTo(0L));
flushJob(job.getId(), true);
closeJob(job.getId());

assertThat(getBuckets(jobId).size(), equalTo(expectedBuckets.size()));
assertThat(getRecords(jobId), equalTo(expectedRecords));
}

private void testRunJobInTwoPartsAndRevertSnapshotAndRunToCompletion(String jobId, boolean deleteInterveningResults) throws Exception {
TimeValue bucketSpan = TimeValue.timeValueHours(1);
long startTime = 1491004800000L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ private Optional<Quantiles> getQuantiles() throws Exception {
AtomicReference<Exception> errorHolder = new AtomicReference<>();
AtomicReference<Optional<Quantiles>> resultHolder = new AtomicReference<>();
CountDownLatch latch = new CountDownLatch(1);
jobResultsProvider.getAutodetectParams(JobTests.buildJobBuilder(JOB_ID).build(), params -> {
jobResultsProvider.getAutodetectParams(JobTests.buildJobBuilder(JOB_ID).setModelSnapshotId("test_snapshot").build(), params -> {
resultHolder.set(Optional.ofNullable(params.quantiles()));
latch.countDown();
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,31 @@ private void getModelSnapshot(RevertModelSnapshotAction.Request request, JobResu
Consumer<Exception> errorHandler) {
logger.info("Reverting to snapshot '" + request.getSnapshotId() + "'");

if (ModelSnapshot.isTheEmptySnapshot(request.getSnapshotId())) {
handler.accept(ModelSnapshot.emptySnapshot(request.getJobId()));
return;
}

provider.getModelSnapshot(request.getJobId(), request.getSnapshotId(), modelSnapshot -> {
if (modelSnapshot == null) {
throw new ResourceNotFoundException(Messages.getMessage(Messages.REST_NO_SUCH_MODEL_SNAPSHOT, request.getSnapshotId(),
request.getJobId()));
throw missingSnapshotException(request);
}
handler.accept(modelSnapshot.result);
}, errorHandler);
}

private static ResourceNotFoundException missingSnapshotException(RevertModelSnapshotAction.Request request) {
return new ResourceNotFoundException(Messages.getMessage(Messages.REST_NO_SUCH_MODEL_SNAPSHOT, request.getSnapshotId(),
request.getJobId()));
}

private ActionListener<RevertModelSnapshotAction.Response> wrapDeleteOldAnnotationsListener(
ActionListener<RevertModelSnapshotAction.Response> listener,
ModelSnapshot modelSnapshot,
String jobId) {

return ActionListener.wrap(response -> {
Date deleteAfter = modelSnapshot.getLatestResultTimeStamp();
Date deleteAfter = modelSnapshot.getLatestResultTimeStamp() == null ? new Date(0) : modelSnapshot.getLatestResultTimeStamp();
logger.info("[{}] Removing intervening annotations after reverting model: deleting annotations after [{}]", jobId, deleteAfter);

JobDataDeleter dataDeleter = new JobDataDeleter(client, jobId);
Expand Down Expand Up @@ -176,7 +185,7 @@ private ActionListener<RevertModelSnapshotAction.Response> wrapDeleteOldDataList
// wrap the listener with one that invokes the OldDataRemover on
// acknowledged responses
return ActionListener.wrap(response -> {
Date deleteAfter = modelSnapshot.getLatestResultTimeStamp();
Date deleteAfter = modelSnapshot.getLatestResultTimeStamp() == null ? new Date(0) : modelSnapshot.getLatestResultTimeStamp();
logger.info("[{}] Removing intervening records after reverting model: deleting results after [{}]", jobId, deleteAfter);

JobDataDeleter dataDeleter = new JobDataDeleter(client, jobId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ private void validate(Job job, JobUpdate jobUpdate, ActionListener<Void> handler
}

private void validateModelSnapshotIdUpdate(Job job, String modelSnapshotId, VoidChainTaskExecutor voidChainTaskExecutor) {
if (modelSnapshotId != null) {
if (modelSnapshotId != null && ModelSnapshot.isTheEmptySnapshot(modelSnapshotId) == false) {
voidChainTaskExecutor.add(listener -> {
jobResultsProvider.getModelSnapshot(job.getId(), modelSnapshotId, newModelSnapshot -> {
if (newModelSnapshot == null) {
Expand Down Expand Up @@ -599,6 +599,11 @@ public void revertSnapshot(RevertModelSnapshotAction.Request request, ActionList
// Step 3. After the model size stats is persisted, also persist the snapshot's quantiles and respond
// -------
CheckedConsumer<IndexResponse, Exception> modelSizeStatsResponseHandler = response -> {
// In case we are reverting to the empty snapshot the quantiles will be null
if (modelSnapshot.getQuantiles() == null) {
actionListener.onResponse(new RevertModelSnapshotAction.Response(modelSnapshot));
return;
}
jobResultsPersister.persistQuantiles(modelSnapshot.getQuantiles(), WriteRequest.RefreshPolicy.IMMEDIATE,
ActionListener.wrap(quantilesResponse -> {
// The quantiles can be large, and totally dominate the output -
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,12 @@ public void getAutodetectParams(Job job, String snapshotId, Consumer<AutodetectP
MultiSearchRequestBuilder msearch = client.prepareMultiSearch()
.add(createLatestDataCountsSearch(resultsIndex, jobId))
.add(createLatestModelSizeStatsSearch(resultsIndex))
.add(createLatestTimingStatsSearch(resultsIndex, jobId))
// These next two document IDs never need to be the legacy ones due to the rule
// that you cannot open a 5.4 job in a subsequent version of the product
.add(createDocIdSearch(resultsIndex, ModelSnapshot.documentId(jobId, snapshotId)))
.add(createDocIdSearch(stateIndex, Quantiles.documentId(jobId)));
.add(createLatestTimingStatsSearch(resultsIndex, jobId));

if (snapshotId != null) {
msearch.add(createDocIdSearch(resultsIndex, ModelSnapshot.documentId(jobId, snapshotId)));
msearch.add(createDocIdSearch(stateIndex, Quantiles.documentId(jobId)));
}

for (String filterId : job.getAnalysisConfig().extractReferencedFilters()) {
msearch.add(createDocIdSearch(MlMetaIndex.indexName(), MlFilter.documentId(filterId)));
Expand Down

0 comments on commit 350c2b9

Please sign in to comment.