Skip to content

Commit

Permalink
[ML] Improve scalability of NLP models (#87366)
Browse files Browse the repository at this point in the history
We can improve a model's latency and throughput by using more
threads for each inference request. We can also improve its
throughput by processing multiple inference requests in parallel.
A user can set the model's `number_of_allocations` and `threads_per_allocation`
settings to increase performance.
However, if we use more threads than the node's allocated processors,
we end up with thread oversubscription and performance deteriorates throughout.

This commit changes the way model allocations are distributed across the
ML nodes of the cluster. Up to now, we were trying to allocate each model
on every node. Now, we make use of the `AssignmentPlanner` class introduced
in #86004 in order to compute an assignment plan that distributes model
allocations across the cluster while we maximize the number of allocations
we provide to each model without oversubscribing the nodes' CPU.
  • Loading branch information
dimitris-athanasiou committed Jun 22, 2022
1 parent f98a02d commit a9d1fa2
Show file tree
Hide file tree
Showing 50 changed files with 2,610 additions and 1,055 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/87366.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87366
summary: Improve scalability of NLP models
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,36 @@
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class UpdateTrainedModelAssignmentStateAction extends ActionType<AcknowledgedResponse> {
public static final UpdateTrainedModelAssignmentStateAction INSTANCE = new UpdateTrainedModelAssignmentStateAction();
public class UpdateTrainedModelAssignmentRoutingInfoAction extends ActionType<AcknowledgedResponse> {
public static final UpdateTrainedModelAssignmentRoutingInfoAction INSTANCE = new UpdateTrainedModelAssignmentRoutingInfoAction();
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";

private UpdateTrainedModelAssignmentStateAction() {
private UpdateTrainedModelAssignmentRoutingInfoAction() {
super(NAME, AcknowledgedResponse::readFrom);
}

public static class Request extends MasterNodeRequest<Request> {
private final String nodeId;
private final String modelId;
private final RoutingStateAndReason routingState;
private final RoutingInfoUpdate update;

public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
public Request(String nodeId, String modelId, RoutingInfoUpdate update) {
this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
this.update = ExceptionsHelper.requireNonNull(update, "update");
}

public Request(StreamInput in) throws IOException {
super(in);
this.nodeId = in.readString();
this.modelId = in.readString();
this.routingState = new RoutingStateAndReason(in);
this.update = new RoutingInfoUpdate(in);
}

public String getNodeId() {
Expand All @@ -53,8 +53,8 @@ public String getModelId() {
return modelId;
}

public RoutingStateAndReason getRoutingState() {
return routingState;
public RoutingInfoUpdate getUpdate() {
return update;
}

@Override
Expand All @@ -67,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(nodeId);
out.writeString(modelId);
routingState.writeTo(out);
update.writeTo(out);
}

@Override
Expand All @@ -77,17 +77,17 @@ public boolean equals(Object o) {
Request request = (Request) o;
return Objects.equals(nodeId, request.nodeId)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(routingState, request.routingState);
&& Objects.equals(update, request.update);
}

@Override
public int hashCode() {
return Objects.hash(nodeId, modelId, routingState);
return Objects.hash(nodeId, modelId, update);
}

@Override
public String toString() {
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", update=" + update + '}';
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.assignment;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class RoutingInfo implements ToXContentObject, Writeable {

private static final ParseField CURRENT_ALLOCATIONS = new ParseField("current_allocations");
private static final ParseField TARGET_ALLOCATIONS = new ParseField("target_allocations");
private static final ParseField ROUTING_STATE = new ParseField("routing_state");
private static final ParseField REASON = new ParseField("reason");

private static final ConstructingObjectParser<RoutingInfo, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_routing_state",
a -> new RoutingInfo((Integer) a[0], (Integer) a[1], RoutingState.fromString((String) a[2]), (String) a[3])
);
static {
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), CURRENT_ALLOCATIONS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), TARGET_ALLOCATIONS);
PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
}

public static RoutingInfo fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final int currentAllocations;
private final int targetAllocations;
private final RoutingState state;
private final String reason;

// There may be objects in cluster state prior to 8.4 that do not contain values for currentAllocations and targetAllocations.
private RoutingInfo(
@Nullable Integer currentAllocations,
@Nullable Integer targetAllocations,
RoutingState state,
@Nullable String reason
) {
this(currentAllocations == null ? 0 : currentAllocations, targetAllocations == null ? 0 : targetAllocations, state, reason);
}

public RoutingInfo(int currentAllocations, int targetAllocations, RoutingState state, String reason) {
this.currentAllocations = currentAllocations;
this.targetAllocations = targetAllocations;
this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE);
this.reason = reason;
}

public RoutingInfo(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.currentAllocations = in.readVInt();
this.targetAllocations = in.readVInt();
} else {
this.currentAllocations = 0;
this.targetAllocations = 0;
}
this.state = in.readEnum(RoutingState.class);
this.reason = in.readOptionalString();
}

public int getCurrentAllocations() {
return currentAllocations;
}

public int getTargetAllocations() {
return targetAllocations;
}

public RoutingState getState() {
return state;
}

@Nullable
public String getReason() {
return reason;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeVInt(currentAllocations);
out.writeVInt(targetAllocations);
}
out.writeEnum(state);
out.writeOptionalString(reason);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CURRENT_ALLOCATIONS.getPreferredName(), currentAllocations);
builder.field(TARGET_ALLOCATIONS.getPreferredName(), targetAllocations);
builder.field(ROUTING_STATE.getPreferredName(), state);
if (reason != null) {
builder.field(REASON.getPreferredName(), reason);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RoutingInfo that = (RoutingInfo) o;
return currentAllocations == that.currentAllocations
&& targetAllocations == that.targetAllocations
&& state == that.state
&& Objects.equals(reason, that.reason);
}

@Override
public int hashCode() {
return Objects.hash(currentAllocations, targetAllocations, state, reason);
}

@Override
public String toString() {
return "RoutingInfo{"
+ "current_allocations="
+ currentAllocations
+ ", target_allocations="
+ targetAllocations
+ ", reason='"
+ reason
+ '\''
+ ", state="
+ state
+ '}';
}

public boolean isRoutable() {
return state == RoutingState.STARTED && currentAllocations > 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.assignment;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class RoutingInfoUpdate implements Writeable {

private final Optional<Integer> numberOfAllocations;
private final Optional<RoutingStateAndReason> stateAndReason;

public static RoutingInfoUpdate updateNumberOfAllocations(int numberOfAllocations) {
return new RoutingInfoUpdate(Optional.of(numberOfAllocations), Optional.empty());
}

public static RoutingInfoUpdate updateStateAndReason(RoutingStateAndReason routingStateAndReason) {
return new RoutingInfoUpdate(Optional.empty(), Optional.of(routingStateAndReason));
}

private RoutingInfoUpdate(Optional<Integer> numberOfAllocations, Optional<RoutingStateAndReason> stateAndReason) {
this.numberOfAllocations = Objects.requireNonNull(numberOfAllocations);
this.stateAndReason = Objects.requireNonNull(stateAndReason);
}

public RoutingInfoUpdate(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
numberOfAllocations = Optional.ofNullable(in.readOptionalVInt());
stateAndReason = Optional.ofNullable(in.readOptionalWriteable(RoutingStateAndReason::new));
} else {
numberOfAllocations = Optional.empty();
stateAndReason = Optional.of(new RoutingStateAndReason(in));
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalVInt(numberOfAllocations.orElse(null));
out.writeOptionalWriteable(stateAndReason.orElse(null));
} else {
assert stateAndReason.isPresent() : "updating routing info while nodes prior to 8.4.0 should only contain state and reason";
stateAndReason.get().writeTo(out);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RoutingInfoUpdate that = (RoutingInfoUpdate) o;
return Objects.equals(numberOfAllocations, that.numberOfAllocations) && Objects.equals(stateAndReason, that.stateAndReason);
}

@Override
public int hashCode() {
return Objects.hash(numberOfAllocations, stateAndReason);
}

@Override
public String toString() {
return "RoutingInfoUpdate{" + "numberOfAllocations=" + numberOfAllocations + ", stateAndReason=" + stateAndReason + '}';
}

public Optional<Integer> getNumberOfAllocations() {
return numberOfAllocations;
}

public Optional<RoutingStateAndReason> getStateAndReason() {
return stateAndReason;
}

public RoutingInfo apply(RoutingInfo routingInfo) {
int currentAllocations = numberOfAllocations.orElse(routingInfo.getCurrentAllocations());
RoutingState state = routingInfo.getState();
String reason = routingInfo.getReason();
if (stateAndReason.isPresent()) {
state = stateAndReason.get().getState();
reason = stateAndReason.get().getReason();
}
return new RoutingInfo(currentAllocations, routingInfo.getTargetAllocations(), state, reason);
}
}

0 comments on commit a9d1fa2

Please sign in to comment.