Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Use perAllocation and perDeployment memory usage in the model assignment planner #98874

Merged
merged 42 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
08e3c53
comments on work scope
valeriy42 Aug 24, 2023
96c15d3
add memory estimation to AssignmentPlan.Deployment
valeriy42 Aug 25, 2023
6e00431
Updated linear solver and rounding routines
valeriy42 Aug 25, 2023
259f883
fix unit test compilation errors
valeriy42 Aug 25, 2023
613b92e
extend unit tests
valeriy42 Aug 30, 2023
dc8070d
change memoryUsage to memoryBytes in deployments
valeriy42 Aug 31, 2023
9f83f09
fixing original unit tests
valeriy42 Sep 7, 2023
3ba327c
Unit test for down scaling cluster
valeriy42 Oct 6, 2023
be7cfcd
optimal allocation test works
valeriy42 Oct 9, 2023
0d6a11b
extend unit tests with new memory fields
valeriy42 Oct 9, 2023
2a58efd
formatting
valeriy42 Oct 9, 2023
f9c338b
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
valeriy42 Oct 9, 2023
71756f8
Update docs/changelog/98874.yaml
valeriy42 Oct 9, 2023
6badb52
Update .gitignore
valeriy42 Oct 9, 2023
d9ba0bd
remove dead code
valeriy42 Oct 9, 2023
a348b58
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
valeriy42 Oct 9, 2023
4089462
fix unit test after merging
valeriy42 Oct 9, 2023
c11d11a
spotless check
valeriy42 Oct 9, 2023
697c321
Fix failing tests in AssignmentPlannerTests
valeriy42 Oct 13, 2023
e1fdf55
AssignmentPlannerTests all tests green
valeriy42 Oct 13, 2023
786f68d
fixed AssignmentPlanTests
valeriy42 Oct 13, 2023
76880dc
fixed ZoneAwareAssignmentPlanner
valeriy42 Oct 13, 2023
4b12f3f
PreserveOneAllocationTests fixed
valeriy42 Oct 13, 2023
6f531d5
PreserveAllAllocationsTests fixed
valeriy42 Oct 13, 2023
3a763ea
formatting
valeriy42 Oct 13, 2023
e74374a
Remove assertion for nodes preserving allocations
valeriy42 Oct 16, 2023
8c33d6d
Merge branch 'main' into update-assignment-planner
elasticmachine Oct 16, 2023
85b1f91
fixing unit tests
valeriy42 Oct 23, 2023
446a835
Merge branch 'main' into update-assignment-planner
elasticmachine Oct 24, 2023
fe80f41
all tests green
valeriy42 Oct 25, 2023
35e0ccd
formatting
valeriy42 Oct 26, 2023
e22d648
fix TrainedModelAssignmentRebalancerTests
valeriy42 Oct 26, 2023
8e278d7
remove dead code
valeriy42 Oct 26, 2023
60b53ad
extend unit tests with new memory format cases
valeriy42 Oct 27, 2023
50d6d2e
extend integration tests to test memory estimation in mixed cluster s…
valeriy42 Oct 31, 2023
6cefe3f
formatting
valeriy42 Oct 31, 2023
d1373e7
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
valeriy42 Oct 31, 2023
8b3963d
add references to the refactoring issue
valeriy42 Oct 31, 2023
b3eb3f4
fix integration test
valeriy42 Oct 31, 2023
92753e8
fix forbidden api check
valeriy42 Oct 31, 2023
b8f3511
assign models to node only when possible
valeriy42 Nov 2, 2023
ccf5555
Merge branch 'main' into update-assignment-planner
elasticmachine Nov 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/98874.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 98874
summary: Estimate the memory required to deploy trained models more accurately
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Randomness;
Expand Down Expand Up @@ -96,6 +97,10 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
private final Instant startTime;
private final int maxAssignedAllocations;

public static boolean useNewMemoryFields(TransportVersion minClusterVersion) {
return minClusterVersion.onOrAfter(TransportVersions.V_8_500_064);
}

public static TrainedModelAssignment fromXContent(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
Expand Down Expand Up @@ -296,29 +298,23 @@ private void modelSizeStats(
for (TrainedModelConfig model : models) {
if (model.getModelType() == TrainedModelType.PYTORCH) {
long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
// We ensure that in the mixed cluster state trained model stats uses the same values for memory estimation
// as the rebalancer.
boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(
TransportVersionUtils.getMinTransportVersion(clusterService.state())
);
long estimatedMemoryUsageBytes = totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0,
numberOfAllocations
)
: 0L;
modelSizeStatsByModelId.put(
model.getModelId(),
new TrainedModelSizeStats(
totalDefinitionLength,
totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L
)
new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes)
);
} else {
modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
Expand Down Expand Up @@ -76,6 +77,8 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
Expand Down Expand Up @@ -644,12 +647,14 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
Map<DiscoveryNode, NodeLoad> nodeLoads = detectNodeLoads(nodes, currentState);
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState);

boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState));
TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
currentMetadata,
nodeLoads,
nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState),
modelToAdd,
allocatedProcessorsScale
allocatedProcessorsScale,
useNewMemoryFields
);

Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,22 @@ class TrainedModelAssignmentRebalancer {
private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
private final int allocatedProcessorsScale;

private final boolean useNewMemoryFields;

TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata currentMetadata,
Map<DiscoveryNode, NodeLoad> nodeLoads,
Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone,
Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd,
int allocatedProcessorsScale
int allocatedProcessorsScale,
boolean useNewMemoryFields
) {
this.currentMetadata = Objects.requireNonNull(currentMetadata);
this.nodeLoads = Objects.requireNonNull(nodeLoads);
this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone);
this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd);
this.allocatedProcessorsScale = allocatedProcessorsScale;
this.useNewMemoryFields = useNewMemoryFields;
}

TrainedModelAssignmentMetadata.Builder rebalance() {
Expand Down Expand Up @@ -138,9 +142,11 @@ private static void copyAssignments(
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
dest.accountMemory(m, originalNode);
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
dest.accountMemory(m, originalNode, requiredMemory);
}
}
}
Expand Down Expand Up @@ -168,11 +174,14 @@ private AssignmentPlan computePlanForNormalPriorityModels(
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
return new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
currentAssignments,
assignment.getMaxAssignedAllocations()
assignment.getMaxAssignedAllocations(),
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
);
})
.forEach(planDeployments::add);
Expand All @@ -181,11 +190,14 @@ private AssignmentPlan computePlanForNormalPriorityModels(
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getModelBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0
0,
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0
)
);
}
Expand Down Expand Up @@ -217,12 +229,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
.map(
assignment -> new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory),
assignment.getMaxAssignedAllocations(),
Priority.LOW
Priority.LOW,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
)
)
.forEach(planDeployments::add);
Expand All @@ -231,12 +245,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getModelBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
Priority.LOW
Priority.LOW,
(useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ private Node modifyNodePreservingAllocations(Node n) {
int coresUsed = 0;
for (Deployment m : deployments) {
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
bytesUsed += m.memoryBytes();
int allocations = m.currentAllocationsByNodeId().get(n.id());
bytesUsed += m.estimateMemoryUsageBytes(allocations);
coresUsed += calculateUsedCores(n, m);
}
}
Expand All @@ -58,7 +59,9 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
m.allocations() - calculatePreservedAllocations(m),
m.threadsPerAllocation(),
calculateAllocationsPerNodeToPreserve(m),
m.maxAssignedAllocations()
m.maxAssignedAllocations(),
m.perDeploymentMemoryBytes(),
m.perAllocationMemoryBytes()
);
}

Expand All @@ -67,28 +70,37 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// they will not match the models/nodes members we have in this class.
// Therefore, we build a lookup table based on the ids so we can merge the plan
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> assignmentsByModelNodeIdPair = new HashMap<>();
final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>();
for (Deployment m : assignmentPlan.models()) {
Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) {
assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
}
}

AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments);
for (Deployment m : deployments) {
for (Node n : nodes) {
int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0);
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) {
allocations += addPreservedAllocations(n, m);
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
mergedPlanBuilder.accountMemory(m, n);
for (Node n : nodes) {
// TODO (#101612) Should the first loop happen in the builder constructor?
for (Deployment deploymentAllocationsToPreserve : deployments) {

// if the model m is already allocated on the node n and I want to preserve this allocation
int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve);
if (preservedAllocations > 0) {
long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations);
if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory);
}
}
if (allocations > 0) {
mergedPlanBuilder.assignModelToNode(m, n, allocations);
}
for (Deployment deploymentNewAllocations : deployments) {
int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.id(), n.id()),
0
);

long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations);
if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations);
}
}
}
Expand Down