Skip to content

Commit

Permalink
[ML] Add Deployment Id for NLP deployments (#95053)
Browse files Browse the repository at this point in the history
Refactoring and renaming usages of model id to deployment id
  • Loading branch information
davidkyle committed Apr 11, 2023
1 parent b5e2222 commit d704f21
Show file tree
Hide file tree
Showing 79 changed files with 1,680 additions and 1,187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ public static String dataFrameAnalyticsId(String taskId) {
return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length());
}

public static String trainedModelAssignmentTaskDescription(String modelId) {
return TrainedModelConfig.MODEL_ID.getPreferredName() + "[" + modelId + "]";
public static String trainedModelAssignmentTaskDescription(String deploymentId) {
// A description containing deployment_id[XXX] is more accurate
// than model_id[XXX] but the legacy description cannot be changed now
return TrainedModelConfig.MODEL_ID.getPreferredName() + "[" + deploymentId + "]";
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,43 +31,43 @@ private ClearDeploymentCacheAction() {
}

public static class Request extends BaseTasksRequest<Request> {
private final String modelId;
private final String deploymentId;

public Request(String modelId) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID);
public Request(String deploymentId) {
this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, InferModelAction.Request.ID);
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.deploymentId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeString(deploymentId);
}

public String getModelId() {
return modelId;
public String getDeploymentId() {
return deploymentId;
}

@Override
public boolean match(Task task) {
return StartTrainedModelDeploymentAction.TaskMatcher.match(task, modelId);
return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(modelId, request.modelId);
return Objects.equals(deploymentId, request.deploymentId);
}

@Override
public int hashCode() {
return Objects.hash(modelId);
return Objects.hash(deploymentId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,30 @@ private GetDeploymentStatsAction() {

public static class Request extends BaseTasksRequest<GetDeploymentStatsAction.Request> {

private final String modelId;
private final String deploymentId;
// used internally this should not be set by the REST request
private List<String> expandedIds;

public Request(String modelId) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID);
this.expandedIds = Collections.singletonList(modelId);
public Request(String deploymentId) {
this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, InferModelAction.Request.DEPLOYMENT_ID);
this.expandedIds = Collections.singletonList(deploymentId);
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.deploymentId = in.readString();
this.expandedIds = in.readStringList();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeString(deploymentId);
out.writeStringCollection(expandedIds);
}

public String getModelId() {
return modelId;
public String getDeploymentId() {
return deploymentId;
}

public void setExpandedIds(List<String> expandedIds) {
Expand All @@ -78,12 +78,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(modelId, request.modelId) && Objects.equals(expandedIds, request.expandedIds);
return Objects.equals(deploymentId, request.deploymentId) && Objects.equals(expandedIds, request.expandedIds);
}

@Override
public int hashCode() {
return Objects.hash(modelId, expandedIds);
return Objects.hash(deploymentId, expandedIds);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ private InferModelAction(String name) {

public static class Request extends ActionRequest {

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField ID = new ParseField("id");
public static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
public static final ParseField DOCS = new ParseField("docs");
public static final ParseField TIMEOUT = new ParseField("timeout");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");

static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, Builder::new);
static {
PARSER.declareString(Builder::setModelId, MODEL_ID);
PARSER.declareString(Builder::setId, ID);
PARSER.declareObjectArray(Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
PARSER.declareNamedObject(
Expand All @@ -66,18 +67,18 @@ public static class Request extends ActionRequest {
);
}

public static Builder parseRequest(String modelId, XContentParser parser) {
public static Builder parseRequest(String id, XContentParser parser) {
Builder builder = PARSER.apply(parser, null);
if (modelId != null) {
builder.setModelId(modelId);
if (id != null) {
builder.setId(id);
}
return builder;
}

public static final TimeValue DEFAULT_TIMEOUT_FOR_API = TimeValue.timeValueSeconds(10);
public static final TimeValue DEFAULT_TIMEOUT_FOR_INGEST = TimeValue.MAX_VALUE;

private final String modelId;
private final String id;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfigUpdate update;
private final boolean previouslyLicensed;
Expand All @@ -95,13 +96,13 @@ public static Builder parseRequest(String modelId, XContentParser parser) {
* to prefer slow ingest over dropping documents.
*/
public static Request forIngestDocs(
String modelId,
String id,
List<Map<String, Object>> docs,
InferenceConfigUpdate update,
boolean previouslyLicensed
) {
return new Request(
ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID),
ExceptionsHelper.requireNonNull(id, InferModelAction.Request.ID),
update,
ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS),
null,
Expand All @@ -116,9 +117,9 @@ public static Request forIngestDocs(
* The inference timeout (how long the request waits in
* the inference queue for) is set to {@code #DEFAULT_TIMEOUT_FOR_API}
*/
public static Request forTextInput(String modelId, InferenceConfigUpdate update, List<String> textInput) {
public static Request forTextInput(String id, InferenceConfigUpdate update, List<String> textInput) {
return new Request(
modelId,
id,
update,
List.of(),
ExceptionsHelper.requireNonNull(textInput, "inference text input"),
Expand All @@ -128,14 +129,14 @@ public static Request forTextInput(String modelId, InferenceConfigUpdate update,
}

Request(
String modelId,
String id,
InferenceConfigUpdate inferenceConfigUpdate,
List<Map<String, Object>> docs,
List<String> textInput,
TimeValue inferenceTimeout,
boolean previouslyLicensed
) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.id = ExceptionsHelper.requireNonNull(id, ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(docs, DOCS.getPreferredName()));
this.update = ExceptionsHelper.requireNonNull(inferenceConfigUpdate, "inference_config");
this.textInput = textInput;
Expand All @@ -145,7 +146,7 @@ public static Request forTextInput(String modelId, InferenceConfigUpdate update,

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.id = in.readString();
this.objectsToInfer = in.readImmutableList(StreamInput::readMap);
this.update = in.readNamedWriteable(InferenceConfigUpdate.class);
this.previouslyLicensed = in.readBoolean();
Expand All @@ -172,8 +173,8 @@ public int numberOfDocuments() {
}
}

public String getModelId() {
return modelId;
public String getId() {
return id;
}

public List<Map<String, Object>> getObjectsToInfer() {
Expand Down Expand Up @@ -217,7 +218,7 @@ public ActionRequestValidationException validate() {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeString(id);
out.writeCollection(objectsToInfer, StreamOutput::writeGenericMap);
out.writeNamedWriteable(update);
out.writeBoolean(previouslyLicensed);
Expand All @@ -237,7 +238,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Request that = (InferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
return Objects.equals(id, that.id)
&& Objects.equals(update, that.update)
&& Objects.equals(previouslyLicensed, that.previouslyLicensed)
&& Objects.equals(inferenceTimeout, that.inferenceTimeout)
Expand All @@ -248,25 +249,25 @@ public boolean equals(Object o) {

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("infer_trained_model[%s]", modelId), parentTaskId, headers);
return new CancellableTask(id, type, action, format("infer_trained_model[%s]", this.id), parentTaskId, headers);
}

@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput, highPriority);
return Objects.hash(id, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput, highPriority);
}

public static class Builder {

private String modelId;
private String id;
private List<Map<String, Object>> docs;
private TimeValue timeout;
private InferenceConfigUpdate update = new EmptyConfigUpdate();

private Builder() {}

public Builder setModelId(String modelId) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
public Builder setId(String id) {
this.id = ExceptionsHelper.requireNonNull(id, ID);
return this;
}

Expand Down Expand Up @@ -294,7 +295,7 @@ private Builder setInferenceTimeout(String inferenceTimeout) {
}

public Request build() {
return new Request(modelId, update, docs, null, timeout, false);
return new Request(id, update, docs, null, timeout, false);
}
}

Expand All @@ -303,21 +304,21 @@ public Request build() {
public static class Response extends ActionResponse implements ToXContentObject {

private final List<InferenceResults> inferenceResults;
private final String modelId;
private final String id;
private final boolean isLicensed;

public Response(List<InferenceResults> inferenceResults, String modelId, boolean isLicensed) {
public Response(List<InferenceResults> inferenceResults, String id, boolean isLicensed) {
super();
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
this.isLicensed = isLicensed;
this.modelId = modelId;
this.id = id;
}

public Response(StreamInput in) throws IOException {
super(in);
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
this.isLicensed = in.readBoolean();
this.modelId = in.readOptionalString();
this.id = in.readOptionalString();
}

public List<InferenceResults> getInferenceResults() {
Expand All @@ -328,30 +329,28 @@ public boolean isLicensed() {
return isLicensed;
}

public String getModelId() {
return modelId;
public String getId() {
return id;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteableList(inferenceResults);
out.writeBoolean(isLicensed);
out.writeOptionalString(modelId);
out.writeOptionalString(id);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Response that = (InferModelAction.Response) o;
return isLicensed == that.isLicensed
&& Objects.equals(inferenceResults, that.inferenceResults)
&& Objects.equals(modelId, that.modelId);
return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults) && Objects.equals(id, that.id);
}

@Override
public int hashCode() {
return Objects.hash(inferenceResults, isLicensed, modelId);
return Objects.hash(inferenceResults, isLicensed, id);
}

public static Builder builder() {
Expand All @@ -375,7 +374,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

public static class Builder {
private List<InferenceResults> inferenceResults = new ArrayList<>();
private String modelId;
private String id;
private boolean isLicensed;

public Builder addInferenceResults(List<InferenceResults> inferenceResults) {
Expand All @@ -388,13 +387,13 @@ public Builder setLicensed(boolean licensed) {
return this;
}

public Builder setModelId(String modelId) {
this.modelId = modelId;
public Builder setId(String id) {
this.id = id;
return this;
}

public Response build() {
return new Response(inferenceResults, modelId, isLicensed);
return new Response(inferenceResults, id, isLicensed);
}
}

Expand Down

0 comments on commit d704f21

Please sign in to comment.