Skip to content

Commit

Permalink
[ML] Make model snapshot upgrade autoscaling friendly (#81303)
Browse files Browse the repository at this point in the history
Model snapshot upgrade was not taking autoscaling into account
when doing node assignment.

Backport of #81123
  • Loading branch information
droberts195 committed Dec 8, 2021
1 parent 8d42b0d commit 7d493f6
Show file tree
Hide file tree
Showing 14 changed files with 507 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,27 @@ public static JobState getJobStateModifiedForReassignments(@Nullable PersistentT
return jobState;
}

/**
 * Reports the snapshot upgrade state for one model snapshot of a job.
 * <p>
 * The persistent task is looked up via {@code getSnapshotUpgraderTask} and the result is
 * handed to the single-argument overload, which decides the state to report.
 *
 * @param jobId      ID of the job that owns the model snapshot
 * @param snapshotId ID of the model snapshot being upgraded
 * @param tasks      persistent tasks metadata to search; may be {@code null}
 * @return the state the single-argument overload derives from the task that was looked up
 */
public static SnapshotUpgradeState getSnapshotUpgradeState(String jobId, String snapshotId, @Nullable PersistentTasksCustomMetadata tasks) {
    PersistentTasksCustomMetadata.PersistentTask<?> upgraderTask = getSnapshotUpgraderTask(jobId, snapshotId, tasks);
    return getSnapshotUpgradeState(upgraderTask);
}

/**
 * Derives the snapshot upgrade state from a persistent task.
 *
 * @param task the snapshot upgrade persistent task; may be {@code null}
 * @return {@code STOPPED} when {@code task} is {@code null}, {@code LOADING_OLD_STATE} when the
 *         task has no state object yet, otherwise the state stored in the task's state object
 */
public static SnapshotUpgradeState getSnapshotUpgradeState(@Nullable PersistentTasksCustomMetadata.PersistentTask<?> task) {
    // No task at all is reported as a stopped upgrade.
    if (task == null) {
        return SnapshotUpgradeState.STOPPED;
    }
    SnapshotUpgradeTaskState upgradeTaskState = (SnapshotUpgradeTaskState) task.getState();
    // A task whose state was never set has never been assigned, so report the
    // first phase the upgrade goes through.
    return upgradeTaskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : upgradeTaskState.getState();
}

public static DatafeedState getDatafeedState(String datafeedId, @Nullable PersistentTasksCustomMetadata tasks) {
PersistentTasksCustomMetadata.PersistentTask<?> task = getDatafeedTask(datafeedId, tasks);
if (task == null) {
Expand Down Expand Up @@ -406,8 +427,7 @@ public static MemoryTrackedTaskState getMemoryTrackedTaskState(PersistentTasksCu
case JOB_TASK_NAME:
return getJobStateModifiedForReassignments(task);
case JOB_SNAPSHOT_UPGRADE_TASK_NAME:
SnapshotUpgradeTaskState taskState = (SnapshotUpgradeTaskState) task.getState();
return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
return getSnapshotUpgradeState(task);
case DATA_FRAME_ANALYTICS_TASK_NAME:
return getDataFrameAnalyticsState(task);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;

import java.net.InetAddress;

Expand Down Expand Up @@ -71,6 +74,27 @@ public void testGetDatefeedState() {
assertEquals(DatafeedState.STARTED, MlTasks.getDatafeedState("foo", tasksBuilder.build()));
}

/**
 * Checks what {@code MlTasks.getSnapshotUpgradeState} reports in three situations: no task
 * exists, a task exists but has no state yet, and a task carries an explicit state.
 */
public void testGetSnapshotUpgradeState() {
    String jobId = "foo";
    String snapshotId = "1";
    PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();

    // A missing task is a stopped snapshot upgrade
    assertEquals(SnapshotUpgradeState.STOPPED, MlTasks.getSnapshotUpgradeState(jobId, snapshotId, tasksBuilder.build()));

    tasksBuilder.addTask(
        MlTasks.snapshotUpgradeTaskId(jobId, snapshotId),
        MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
        new SnapshotUpgradeTaskParams(jobId, snapshotId),
        new PersistentTasksCustomMetadata.Assignment("bar", "test assignment")
    );
    // A task with no state yet is reported as doing the first step of the upgrade: loading the old state
    assertEquals(SnapshotUpgradeState.LOADING_OLD_STATE, MlTasks.getSnapshotUpgradeState(jobId, snapshotId, tasksBuilder.build()));

    tasksBuilder.updateTaskState(
        MlTasks.snapshotUpgradeTaskId(jobId, snapshotId),
        new SnapshotUpgradeTaskState(SnapshotUpgradeState.SAVING_NEW_STATE, tasksBuilder.getLastAllocationId(), null)
    );
    // Once a state has been set on the task, that state is reported as-is
    assertEquals(SnapshotUpgradeState.SAVING_NEW_STATE, MlTasks.getSnapshotUpgradeState(jobId, snapshotId, tasksBuilder.build()));
}

public void testGetJobTask() {
assertNull(MlTasks.getJobTask("foo", null));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ public void testMLAutoscalingCapacity() throws Exception {
.collect(Collectors.toList());
NativeMemoryCapacity currentScale = MlAutoscalingDeciderService.currentScale(mlNodes, 30, false);
expectedTierBytes = (long) Math.ceil(
(ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB + 60_000 + BASELINE_OVERHEAD_MB).getBytes() + currentScale.getTier()) * 100
/ 30.0
(ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB + 60_000 + BASELINE_OVERHEAD_MB).getBytes() + currentScale
.getTierMlNativeMemoryRequirement()) * 100 / 30.0
);
expectedNodeBytes = (long) (ByteSizeValue.ofMb(60_000 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0);

Expand Down

0 comments on commit 7d493f6

Please sign in to comment.