-
Notifications
You must be signed in to change notification settings - Fork 24.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Improve scalability of NLP models (#87366)
We can improve a model's latency and throughput by using more threads for each inference request. We can also improve its throughput by processing multiple inference requests in parallel. A user can set the model's `number_of_allocations` and `threads_per_allocation` settings to increase performance. However, if we use more threads than the node's allocated processors, we end up with thread oversubscription and performance deteriorates throughout. This commit changes the way model allocations are distributed across the ML nodes of the cluster. Up to now, we were trying to allocate each model on every node. Now, we make use of the `AssignmentPlanner` class introduced in #86004 in order to compute an assignment plan that distributes model allocations across the cluster while we maximize the number of allocations we provide to each model without oversubscribing the nodes' CPU.
- Loading branch information
1 parent
f98a02d
commit a9d1fa2
Showing
50 changed files
with
2,610 additions
and
1,055 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 87366 | ||
summary: Improve scalability of NLP models | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
155 changes: 155 additions & 0 deletions
155
.../core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.assignment; | ||
|
||
import org.elasticsearch.Version; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class RoutingInfo implements ToXContentObject, Writeable { | ||
|
||
private static final ParseField CURRENT_ALLOCATIONS = new ParseField("current_allocations"); | ||
private static final ParseField TARGET_ALLOCATIONS = new ParseField("target_allocations"); | ||
private static final ParseField ROUTING_STATE = new ParseField("routing_state"); | ||
private static final ParseField REASON = new ParseField("reason"); | ||
|
||
private static final ConstructingObjectParser<RoutingInfo, Void> PARSER = new ConstructingObjectParser<>( | ||
"trained_model_routing_state", | ||
a -> new RoutingInfo((Integer) a[0], (Integer) a[1], RoutingState.fromString((String) a[2]), (String) a[3]) | ||
); | ||
static { | ||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), CURRENT_ALLOCATIONS); | ||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), TARGET_ALLOCATIONS); | ||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE); | ||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); | ||
} | ||
|
||
public static RoutingInfo fromXContent(XContentParser parser) { | ||
return PARSER.apply(parser, null); | ||
} | ||
|
||
private final int currentAllocations; | ||
private final int targetAllocations; | ||
private final RoutingState state; | ||
private final String reason; | ||
|
||
// There may be objects in cluster state prior to 8.4 that do not contain values for currentAllocations and targetAllocations. | ||
private RoutingInfo( | ||
@Nullable Integer currentAllocations, | ||
@Nullable Integer targetAllocations, | ||
RoutingState state, | ||
@Nullable String reason | ||
) { | ||
this(currentAllocations == null ? 0 : currentAllocations, targetAllocations == null ? 0 : targetAllocations, state, reason); | ||
} | ||
|
||
public RoutingInfo(int currentAllocations, int targetAllocations, RoutingState state, String reason) { | ||
this.currentAllocations = currentAllocations; | ||
this.targetAllocations = targetAllocations; | ||
this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE); | ||
this.reason = reason; | ||
} | ||
|
||
public RoutingInfo(StreamInput in) throws IOException { | ||
if (in.getVersion().onOrAfter(Version.V_8_4_0)) { | ||
this.currentAllocations = in.readVInt(); | ||
this.targetAllocations = in.readVInt(); | ||
} else { | ||
this.currentAllocations = 0; | ||
this.targetAllocations = 0; | ||
} | ||
this.state = in.readEnum(RoutingState.class); | ||
this.reason = in.readOptionalString(); | ||
} | ||
|
||
public int getCurrentAllocations() { | ||
return currentAllocations; | ||
} | ||
|
||
public int getTargetAllocations() { | ||
return targetAllocations; | ||
} | ||
|
||
public RoutingState getState() { | ||
return state; | ||
} | ||
|
||
@Nullable | ||
public String getReason() { | ||
return reason; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
if (out.getVersion().onOrAfter(Version.V_8_4_0)) { | ||
out.writeVInt(currentAllocations); | ||
out.writeVInt(targetAllocations); | ||
} | ||
out.writeEnum(state); | ||
out.writeOptionalString(reason); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(CURRENT_ALLOCATIONS.getPreferredName(), currentAllocations); | ||
builder.field(TARGET_ALLOCATIONS.getPreferredName(), targetAllocations); | ||
builder.field(ROUTING_STATE.getPreferredName(), state); | ||
if (reason != null) { | ||
builder.field(REASON.getPreferredName(), reason); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
RoutingInfo that = (RoutingInfo) o; | ||
return currentAllocations == that.currentAllocations | ||
&& targetAllocations == that.targetAllocations | ||
&& state == that.state | ||
&& Objects.equals(reason, that.reason); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(currentAllocations, targetAllocations, state, reason); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "RoutingInfo{" | ||
+ "current_allocations=" | ||
+ currentAllocations | ||
+ ", target_allocations=" | ||
+ targetAllocations | ||
+ ", reason='" | ||
+ reason | ||
+ '\'' | ||
+ ", state=" | ||
+ state | ||
+ '}'; | ||
} | ||
|
||
public boolean isRoutable() { | ||
return state == RoutingState.STARTED && currentAllocations > 0; | ||
} | ||
} |
94 changes: 94 additions & 0 deletions
94
...src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdate.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.assignment; | ||
|
||
import org.elasticsearch.Version; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
public class RoutingInfoUpdate implements Writeable { | ||
|
||
private final Optional<Integer> numberOfAllocations; | ||
private final Optional<RoutingStateAndReason> stateAndReason; | ||
|
||
public static RoutingInfoUpdate updateNumberOfAllocations(int numberOfAllocations) { | ||
return new RoutingInfoUpdate(Optional.of(numberOfAllocations), Optional.empty()); | ||
} | ||
|
||
public static RoutingInfoUpdate updateStateAndReason(RoutingStateAndReason routingStateAndReason) { | ||
return new RoutingInfoUpdate(Optional.empty(), Optional.of(routingStateAndReason)); | ||
} | ||
|
||
private RoutingInfoUpdate(Optional<Integer> numberOfAllocations, Optional<RoutingStateAndReason> stateAndReason) { | ||
this.numberOfAllocations = Objects.requireNonNull(numberOfAllocations); | ||
this.stateAndReason = Objects.requireNonNull(stateAndReason); | ||
} | ||
|
||
public RoutingInfoUpdate(StreamInput in) throws IOException { | ||
if (in.getVersion().onOrAfter(Version.V_8_4_0)) { | ||
numberOfAllocations = Optional.ofNullable(in.readOptionalVInt()); | ||
stateAndReason = Optional.ofNullable(in.readOptionalWriteable(RoutingStateAndReason::new)); | ||
} else { | ||
numberOfAllocations = Optional.empty(); | ||
stateAndReason = Optional.of(new RoutingStateAndReason(in)); | ||
} | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
if (out.getVersion().onOrAfter(Version.V_8_4_0)) { | ||
out.writeOptionalVInt(numberOfAllocations.orElse(null)); | ||
out.writeOptionalWriteable(stateAndReason.orElse(null)); | ||
} else { | ||
assert stateAndReason.isPresent() : "updating routing info while nodes prior to 8.4.0 should only contain state and reason"; | ||
stateAndReason.get().writeTo(out); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
RoutingInfoUpdate that = (RoutingInfoUpdate) o; | ||
return Objects.equals(numberOfAllocations, that.numberOfAllocations) && Objects.equals(stateAndReason, that.stateAndReason); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(numberOfAllocations, stateAndReason); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "RoutingInfoUpdate{" + "numberOfAllocations=" + numberOfAllocations + ", stateAndReason=" + stateAndReason + '}'; | ||
} | ||
|
||
public Optional<Integer> getNumberOfAllocations() { | ||
return numberOfAllocations; | ||
} | ||
|
||
public Optional<RoutingStateAndReason> getStateAndReason() { | ||
return stateAndReason; | ||
} | ||
|
||
public RoutingInfo apply(RoutingInfo routingInfo) { | ||
int currentAllocations = numberOfAllocations.orElse(routingInfo.getCurrentAllocations()); | ||
RoutingState state = routingInfo.getState(); | ||
String reason = routingInfo.getReason(); | ||
if (stateAndReason.isPresent()) { | ||
state = stateAndReason.get().getState(); | ||
reason = stateAndReason.get().getReason(); | ||
} | ||
return new RoutingInfo(currentAllocations, routingInfo.getTargetAllocations(), state, reason); | ||
} | ||
} |
Oops, something went wrong.