Skip to content

Commit

Permalink
[ML] Distribute trained model allocations across availability zones (#…
Browse files Browse the repository at this point in the history
…89822)

When a model deployment is started with 2 or more allocations
and availability zones are present we should distribute the allocations
across availability zones so that there is resilience.

This commit adds a `ZoneAwareAssignmentPlanner` that attempts to evenly
distribute the allocations of a deployment across the available zones.
  • Loading branch information
dimitris-athanasiou committed Sep 7, 2022
1 parent fc64b2c commit c733eb8
Show file tree
Hide file tree
Showing 10 changed files with 760 additions and 110 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/89822.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 89822
summary: Distribute trained model allocations across availability zones
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,8 @@ public Collection<Object> createComponents(
clusterService,
threadPool,
new NodeLoadDetector(memoryTracker),
new SystemAuditor(client, clusterService)
new SystemAuditor(client, clusterService),
nodeAvailabilityZoneMapper
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.notifications.SystemAuditor;
Expand Down Expand Up @@ -69,6 +70,7 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
private final SystemAuditor systemAuditor;
private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
private volatile int maxMemoryPercentage;
private volatile boolean useAuto;
private volatile int maxOpenJobs;
Expand All @@ -78,12 +80,14 @@ public TrainedModelAssignmentClusterService(
ClusterService clusterService,
ThreadPool threadPool,
NodeLoadDetector nodeLoadDetector,
SystemAuditor systemAuditor
SystemAuditor systemAuditor,
NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper
) {
this.clusterService = Objects.requireNonNull(clusterService);
this.threadPool = Objects.requireNonNull(threadPool);
this.nodeLoadDetector = Objects.requireNonNull(nodeLoadDetector);
this.systemAuditor = Objects.requireNonNull(systemAuditor);
this.nodeAvailabilityZoneMapper = Objects.requireNonNull(nodeAvailabilityZoneMapper);
this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
Expand Down Expand Up @@ -462,6 +466,7 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata.fromState(currentState),
nodeLoads,
nodeAvailabilityZoneMapper,
modelToAdd
);
return rebalancer.rebalance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.inference.assignment.planning.ZoneAwareAssignmentPlanner;
import org.elasticsearch.xpack.ml.job.NodeLoad;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -41,15 +43,18 @@ class TrainedModelAssignmentRebalancer {

private final TrainedModelAssignmentMetadata currentMetadata;
private final Map<DiscoveryNode, NodeLoad> nodeLoads;
private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
private final Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd;

TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata currentMetadata,
Map<DiscoveryNode, NodeLoad> nodeLoads,
NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper,
Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd
) {
this.currentMetadata = Objects.requireNonNull(currentMetadata);
this.nodeLoads = Objects.requireNonNull(nodeLoads);
this.nodeAvailabilityZoneMapper = Objects.requireNonNull(nodeAvailabilityZoneMapper);
this.modelToAdd = Objects.requireNonNull(modelToAdd);
}

Expand Down Expand Up @@ -78,24 +83,16 @@ private boolean areAllModelsSatisfiedAndNoOutdatedRoutingEntries() {
}

AssignmentPlan computeAssignmentPlan() {
List<AssignmentPlan.Node> planNodes = nodeLoads.entrySet()
.stream()
.filter(e -> Strings.isNullOrEmpty(e.getValue().getError()))
.map(
e -> new AssignmentPlan.Node(
e.getKey().getId(),
// We subtract native inference memory as the planner expects available memory for
// native inference including current assignments.
getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(e.getValue()),
getNodeAllocatedProcessors(e.getKey()).orElse(0)
)
)
.toList();
final Map<List<String>, List<AssignmentPlan.Node>> nodesByZone = createNodesByZoneMap();

final List<AssignmentPlan.Model> planModels = new ArrayList<>(
currentMetadata.modelAssignments().size() + (modelToAdd.isPresent() ? 1 : 0)
);
final Set<String> assignableNodeIds = planNodes.stream().map(AssignmentPlan.Node::id).collect(Collectors.toSet());
final Set<String> assignableNodeIds = nodesByZone.values()
.stream()
.flatMap(List::stream)
.map(AssignmentPlan.Node::id)
.collect(Collectors.toSet());
currentMetadata.modelAssignments().values().stream().map(assignment -> {
Map<String, Integer> currentAssignments = assignment.getNodeRoutingTable()
.entrySet()
Expand Down Expand Up @@ -127,7 +124,38 @@ AssignmentPlan computeAssignmentPlan() {
)
)
);
return new AssignmentPlanner(planNodes, planModels).computePlan();
return new ZoneAwareAssignmentPlanner(nodesByZone, planModels).computePlan();
}

private Map<List<String>, List<AssignmentPlan.Node>> createNodesByZoneMap() {
Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone = nodeAvailabilityZoneMapper.getMlNodesByAvailabilityZone();
return mlNodesByZone.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> {
Collection<DiscoveryNode> discoveryNodes = e.getValue();
List<AssignmentPlan.Node> nodes = new ArrayList<>();
for (DiscoveryNode discoveryNode : discoveryNodes) {
if (nodeLoads.containsKey(discoveryNode)) {
NodeLoad load = nodeLoads.get(discoveryNode);
if (Strings.isNullOrEmpty(load.getError())) {
nodes.add(
new AssignmentPlan.Node(
discoveryNode.getId(),
// We subtract native inference memory as the planner expects available memory for
// native inference including current assignments.
getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(load),
getNodeAllocatedProcessors(discoveryNode).orElse(0)
)
);
} else {
logger.warn(
format("ignoring node [%s] as detecting its load failed with [%s]", discoveryNode.getId(), load.getError())
);
}
} else {
logger.warn(format("ignoring node [%s] as no load could be detected", discoveryNode.getId()));
}
}
return nodes;
}));
}

private static OptionalInt getNodeAllocatedProcessors(DiscoveryNode node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@
* attempt to find a solution that provides at least one allocation to
* previously assigned models.
*/
public class AssignmentPlanner {
class AssignmentPlanner {

private static final Logger logger = LogManager.getLogger(AssignmentPlanner.class);

private final List<Node> nodes;
private final List<Model> models;

public AssignmentPlanner(List<Node> nodes, List<Model> models) {
AssignmentPlanner(List<Node> nodes, List<Model> models) {
this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList();
this.models = models.stream().sorted(Comparator.comparing(Model::id)).toList();
}
Expand All @@ -58,7 +58,7 @@ public AssignmentPlan computePlan() {
return computePlan(true);
}

private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
logger.debug(() -> format("Computing plan for nodes = %s; models = %s", nodes, models));

AssignmentPlan bestPlan;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.assignment.planning;

import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Model;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;

/**
* An assignment planner that is aware of availability zones and tries to distribute
* model allocations evenly across zones in order to achieve better resilience in the
* case nodes in a particular zone become unavailable.
*/
public class ZoneAwareAssignmentPlanner {

private static final Logger logger = LogManager.getLogger(ZoneAwareAssignmentPlanner.class);

/**
* A map from zone attributes to node.
*/
private final Map<List<String>, List<Node>> nodesByZone;

private final List<Model> models;

public ZoneAwareAssignmentPlanner(Map<List<String>, List<Node>> nodesByZone, List<Model> models) {
this.nodesByZone = sortByZone(Objects.requireNonNull(nodesByZone));
this.models = Objects.requireNonNull(models);
}

private static Map<List<String>, List<Node>> sortByZone(Map<List<String>, List<Node>> nodesByZone) {
Map<List<String>, List<Node>> sortedByZone = new TreeMap<>(
Comparator.comparing(zoneAttributes -> zoneAttributes.stream().collect(Collectors.joining()))
);
sortedByZone.putAll(nodesByZone);
return sortedByZone;
}

public AssignmentPlan computePlan() {
// There is only one zone; we can optimize and compute a plan directly.
if (nodesByZone.size() == 1) {
return new AssignmentPlanner(nodesByZone.values().iterator().next(), models).computePlan(true);
}

// First we try to compute a plan without forcing assigning previously assigned models as this may
// produce better plans. If that plan has failed to assign previously assigned models we then try
// again this time prioritizing assigning such models.
AssignmentPlan plan = computePlan(false);
if (plan.arePreviouslyAssignedModelsAssigned() == false) {
plan = computePlan(true);
}
return plan;
}

private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
logger.debug(
() -> format(
"computing plan%s trying to assign previously assigned models",
tryAssigningPreviouslyAssignedModels ? "" : " without"
)
);
// The idea here is that we solve per zone trying to distribute allocations evenly.
// After computing a plan for each zone it is possible that there are still unsatisfied allocations
// that can be allocated, so we solve a final time across all zones preserving the allocations we
// allocated on the first per zone assignment plans.

int remainingZones = nodesByZone.size();
Map<String, Integer> modelIdToRemainingAllocations = models.stream().collect(Collectors.toMap(Model::id, Model::allocations));
List<AssignmentPlan> plans = new ArrayList<>();
for (var zoneToNodes : nodesByZone.entrySet()) {
logger.debug(() -> format("computing plan for availability zone %s", zoneToNodes.getKey()));
AssignmentPlan plan = computeZonePlan(
zoneToNodes.getValue(),
modelIdToRemainingAllocations,
remainingZones,
tryAssigningPreviouslyAssignedModels
);
plan.models()
.forEach(
m -> modelIdToRemainingAllocations.computeIfPresent(
m.id(),
(modelId, remainingAllocations) -> remainingAllocations - plan.totalAllocations(m)
)
);
plans.add(plan);
remainingZones--;
}
AssignmentPlan plan = computePlanAcrossAllNodes(plans);
logger.debug(() -> "Zone aware plan =\n" + plan.prettyPrint());
return plan;
}

private AssignmentPlan computeZonePlan(
List<Node> nodes,
Map<String, Integer> modelIdToRemainingAllocations,
int remainingZones,
boolean tryAssigningPreviouslyAssignedModels
) {
Map<String, Integer> modelIdToTargetAllocations = modelIdToRemainingAllocations.entrySet()
.stream()
.filter(e -> e.getValue() > 0)
.collect(Collectors.toMap(e -> e.getKey(), e -> (e.getValue() - 1) / remainingZones + 1));

List<Model> modifiedModels = models.stream()
.filter(m -> modelIdToTargetAllocations.getOrDefault(m.id(), 0) > 0)
.map(
m -> new Model(
m.id(),
m.memoryBytes(),
modelIdToTargetAllocations.get(m.id()),
m.threadsPerAllocation(),
m.currentAllocationsByNodeId(),
// Only force assigning at least once previously assigned models that have not had any allocation yet
(tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations())
? m.maxAssignedAllocations()
: 0
)
)
.toList();
return new AssignmentPlanner(nodes, modifiedModels).computePlan(tryAssigningPreviouslyAssignedModels);
}

private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
logger.debug(() -> "computing plan across all nodes");
final List<Node> allNodes = new ArrayList<>();
nodesByZone.values().forEach(allNodes::addAll);

Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = mergeAllocationsByNodeIdByModelId(plans);

List<Model> modelsAccountingPlans = models.stream()
.map(
m -> new Model(
m.id(),
m.memoryBytes(),
m.allocations(),
m.threadsPerAllocation(),
allocationsByNodeIdByModelId.get(m.id()),
m.maxAssignedAllocations()
)
)
.toList();

PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(allNodes, modelsAccountingPlans);
List<Node> planNodes = preserveAllAllocations.nodesPreservingAllocations();
List<Model> planModels = preserveAllAllocations.modelsPreservingAllocations();
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planModels).solvePlan(false);
plan = preserveAllAllocations.mergePreservedAllocations(plan);
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
}

private AssignmentPlan swapOriginalModelsInPlan(AssignmentPlan plan, List<Node> allNodes, List<Model> planModels) {
final Map<String, Model> originalModelById = models.stream().collect(Collectors.toMap(Model::id, Function.identity()));
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, models);
for (Model m : planModels) {
Optional<Map<Node, Integer>> nodeAssignments = plan.assignments(m);
if (nodeAssignments.isPresent()) {
nodeAssignments.get()
.entrySet()
.forEach(
e -> planBuilder.assignModelToNode(
originalModelById.get(m.id()),
originalNodeById.get(e.getKey().id()),
e.getValue()
)
);
}
}
return planBuilder.build();
}

private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByModelId(List<AssignmentPlan> plans) {
Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = new HashMap<>();
models.forEach(m -> allocationsByNodeIdByModelId.put(m.id(), new HashMap<>()));
for (AssignmentPlan plan : plans) {
for (Model m : plan.models()) {
Map<String, Integer> nodeIdToAllocations = allocationsByNodeIdByModelId.get(m.id());
Optional<Map<Node, Integer>> assignments = plan.assignments(m);
if (assignments.isPresent()) {
for (Map.Entry<Node, Integer> nodeAssignments : assignments.get().entrySet()) {
nodeIdToAllocations.compute(
nodeAssignments.getKey().id(),
(nodeId, existingAllocations) -> existingAllocations == null
? nodeAssignments.getValue()
: existingAllocations + nodeAssignments.getValue()
);
}
}
}
}
return allocationsByNodeIdByModelId;
}
}

0 comments on commit c733eb8

Please sign in to comment.