Revert "[ML] Use perAllocation and perDeployment memory usage in the …
Browse files Browse the repository at this point in the history
…model assignment planner (#98874)" (#101834)

There were a number of backwards-compatibility (BWC) test failures after the PR was merged today. I'll revert it and investigate the failures locally.

Reverts #98874
valeriy42 committed Nov 6, 2023
1 parent 461d004 commit 63f29d4
Showing 20 changed files with 483 additions and 2,076 deletions.
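
For context: the reverted PR (#98874) made the assignment planner use the per-deployment and per-allocation memory fields, but only once every node in the cluster could understand them; in a mixed-version cluster it fell back to the old estimate. The sketch below condenses that gate from the code this commit deletes (the real helper lives on TrainedModelAssignment, as the removed lines in the diffs show; the wrapper class name here is made up for illustration):

```java
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;

// Illustrative wrapper only; condensed from the code removed by this revert.
final class NewMemoryFieldsGate {

    // The new per-deployment/per-allocation memory fields are only used once the
    // oldest node in the cluster serializes them (transport version 8_500_064).
    static boolean useNewMemoryFields(TransportVersion minClusterVersion) {
        return minClusterVersion.onOrAfter(TransportVersions.V_8_500_064);
    }

    // In a mixed-version cluster the reverted code zeroed the new fields so that
    // old and new nodes computed the same memory estimate.
    static long perDeploymentBytesToUse(boolean useNewMemoryFields, long perDeploymentMemoryBytes) {
        return useNewMemoryFields ? perDeploymentMemoryBytes : 0;
    }
}
```

The same ternary pattern appears in the model-size stats and the rebalancer diffs below.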
docs/changelog/98874.yaml — 5 changes: 0 additions & 5 deletions

This file was deleted.

@@ -9,7 +9,6 @@

import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Randomness;
@@ -97,10 +96,6 @@ public final class TrainedModelAssignment implements SimpleDiffable<TrainedModel
private final Instant startTime;
private final int maxAssignedAllocations;

public static boolean useNewMemoryFields(TransportVersion minClusterVersion) {
return minClusterVersion.onOrAfter(TransportVersions.V_8_500_064);
}

public static TrainedModelAssignment fromXContent(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
@@ -45,11 +45,9 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
@@ -298,23 +296,29 @@ private void modelSizeStats(
for (TrainedModelConfig model : models) {
if (model.getModelType() == TrainedModelType.PYTORCH) {
long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
// We ensure that in the mixed cluster state trained model stats uses the same values for memory estimation
// as the rebalancer.
boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(
TransportVersionUtils.getMinTransportVersion(clusterService.state())
);
long estimatedMemoryUsageBytes = totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L;
modelSizeStatsByModelId.put(
model.getModelId(),
new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes)
new TrainedModelSizeStats(
totalDefinitionLength,
totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L
)
);
} else {
modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
@@ -47,7 +47,6 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
Expand Down Expand Up @@ -77,8 +76,6 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
@@ -647,14 +644,12 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
Map<DiscoveryNode, NodeLoad> nodeLoads = detectNodeLoads(nodes, currentState);
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState);

boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState));
TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
currentMetadata,
nodeLoads,
nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState),
modelToAdd,
allocatedProcessorsScale,
useNewMemoryFields
allocatedProcessorsScale
);

Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
@@ -52,22 +52,18 @@ class TrainedModelAssignmentRebalancer {
private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
private final int allocatedProcessorsScale;

private final boolean useNewMemoryFields;

TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata currentMetadata,
Map<DiscoveryNode, NodeLoad> nodeLoads,
Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone,
Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd,
int allocatedProcessorsScale,
boolean useNewMemoryFields
int allocatedProcessorsScale
) {
this.currentMetadata = Objects.requireNonNull(currentMetadata);
this.nodeLoads = Objects.requireNonNull(nodeLoads);
this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone);
this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd);
this.allocatedProcessorsScale = allocatedProcessorsScale;
this.useNewMemoryFields = useNewMemoryFields;
}

TrainedModelAssignmentMetadata.Builder rebalance() {
@@ -142,11 +138,9 @@ private static void copyAssignments(
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
dest.accountMemory(m, originalNode, requiredMemory);
dest.accountMemory(m, originalNode);
}
}
}
@@ -174,14 +168,11 @@ private AssignmentPlan computePlanForNormalPriorityModels(
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
return new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
currentAssignments,
assignment.getMaxAssignedAllocations(),
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
assignment.getMaxAssignedAllocations()
);
})
.forEach(planDeployments::add);
@@ -190,14 +181,11 @@
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.getModelBytes(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0
0
)
);
}
@@ -229,14 +217,12 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
.map(
assignment -> new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory),
assignment.getMaxAssignedAllocations(),
Priority.LOW,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
Priority.LOW
)
)
.forEach(planDeployments::add);
@@ -245,14 +231,12 @@
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.getModelBytes(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
Priority.LOW,
(useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0
Priority.LOW
)
);
}
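
The substance of the rebalancer change: the code being removed handed AssignmentPlan.Deployment the raw model bytes plus the per-deployment and per-allocation memory fields (so the planner could size memory by allocation count), while the restored code goes back to a single pre-computed figure from estimateMemoryUsageBytes(). Below is a rough, hypothetical sketch of the additive shape such an estimate takes; the real formula in StartTrainedModelDeploymentAction.estimateMemoryUsageBytes is not part of this diff, and the constant is a placeholder:

```java
// Hypothetical sketch only; the constant and exact terms are placeholders, not
// the real implementation in StartTrainedModelDeploymentAction.
final class MemoryEstimateSketch {

    // Placeholder for the fixed native-process overhead.
    private static final long NATIVE_OVERHEAD_BYTES = 240L * 1024 * 1024;

    static long estimate(long modelBytes, long perDeploymentBytes, long perAllocationBytes, int allocations) {
        if (perDeploymentBytes > 0 && perAllocationBytes > 0) {
            // New-style estimate: pay the per-deployment cost once and the
            // per-allocation cost for each allocation placed on the node.
            return NATIVE_OVERHEAD_BYTES + modelBytes + perDeploymentBytes + perAllocationBytes * (long) allocations;
        }
        // Legacy estimate: no per-deployment/per-allocation data, so size from the
        // model definition alone.
        return NATIVE_OVERHEAD_BYTES + 2 * modelBytes;
    }
}
```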
@@ -35,8 +35,7 @@ private Node modifyNodePreservingAllocations(Node n) {
int coresUsed = 0;
for (Deployment m : deployments) {
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
int allocations = m.currentAllocationsByNodeId().get(n.id());
bytesUsed += m.estimateMemoryUsageBytes(allocations);
bytesUsed += m.memoryBytes();
coresUsed += calculateUsedCores(n, m);
}
}
@@ -59,9 +58,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
m.allocations() - calculatePreservedAllocations(m),
m.threadsPerAllocation(),
calculateAllocationsPerNodeToPreserve(m),
m.maxAssignedAllocations(),
m.perDeploymentMemoryBytes(),
m.perAllocationMemoryBytes()
m.maxAssignedAllocations()
);
}

@@ -70,37 +67,28 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// they will not match the models/nodes members we have in this class.
// Therefore, we build a lookup table based on the ids so we can merge the plan
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>();
final Map<Tuple<String, String>, Integer> assignmentsByModelNodeIdPair = new HashMap<>();
for (Deployment m : assignmentPlan.models()) {
Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) {
plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
}
}

AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments);
for (Node n : nodes) {
// TODO (#101612) Should the first loop happen in the builder constructor?
for (Deployment deploymentAllocationsToPreserve : deployments) {

// if the model m is already allocated on the node n and I want to preserve this allocation
int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve);
if (preservedAllocations > 0) {
long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations);
if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory);
for (Deployment m : deployments) {
for (Node n : nodes) {
int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0);
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) {
allocations += addPreservedAllocations(n, m);
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
mergedPlanBuilder.accountMemory(m, n);
}
}
}
for (Deployment deploymentNewAllocations : deployments) {
int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.id(), n.id()),
0
);

long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations);
if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations);
if (allocations > 0) {
mergedPlanBuilder.assignModelToNode(m, n, allocations);
}
}
}
