Skip to content

Commit

Permalink
[ML] add model downloader x-pack core changes (#95175)
Browse files Browse the repository at this point in the history
This PR introduces x-pack core changes in order to add 2 new internal actions for downloading and installing prepackaged models. The implementations of these actions are to be added in separate PR's.
  • Loading branch information
Hendrik Muhs committed Apr 12, 2023
1 parent 46f5ee4 commit f8c72ae
Show file tree
Hide file tree
Showing 12 changed files with 1,071 additions and 33 deletions.
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
exports org.elasticsearch.xpack.core.ml.job.results;
exports org.elasticsearch.xpack.core.ml.job.snapshot.upgrade;
exports org.elasticsearch.xpack.core.ml.notifications;
exports org.elasticsearch.xpack.core.ml.packageloader.action;
exports org.elasticsearch.xpack.core.ml.process.writer;
exports org.elasticsearch.xpack.core.ml.stats;
exports org.elasticsearch.xpack.core.ml.utils.time;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand Down Expand Up @@ -36,7 +37,12 @@ private PutTrainedModelAction() {

public static class Request extends AcknowledgedRequest<Request> {

public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) {
public static Request parseRequest(
String modelId,
boolean deferDefinitionValidation,
boolean waitForCompletion,
XContentParser parser
) {
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);

if (builder.getModelId() == null) {
Expand All @@ -54,21 +60,33 @@ public static Request parseRequest(String modelId, boolean deferDefinitionValida
}
// Validations are done against the builder so we can build the full config object.
// This allows us to not worry about serializing a builder class between nodes.
return new Request(builder.validate(true).build(), deferDefinitionValidation);
return new Request(builder.validate(true).build(), deferDefinitionValidation, waitForCompletion);
}

private final TrainedModelConfig config;
private final boolean deferDefinitionDecompression;
private final boolean waitForCompletion;

// TODO: remove this constructor after re-factoring ML parts
public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) {
this(config, deferDefinitionDecompression, false);
}

public Request(TrainedModelConfig config, boolean deferDefinitionDecompression, boolean waitForCompletion) {
this.config = config;
this.deferDefinitionDecompression = deferDefinitionDecompression;
this.waitForCompletion = waitForCompletion;
}

public Request(StreamInput in) throws IOException {
super(in);
this.config = new TrainedModelConfig(in);
this.deferDefinitionDecompression = in.readBoolean();
if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
this.waitForCompletion = in.readBoolean();
} else {
this.waitForCompletion = false;
}
}

public TrainedModelConfig getTrainedModelConfig() {
Expand All @@ -95,24 +113,33 @@ public boolean isDeferDefinitionDecompression() {
return deferDefinitionDecompression;
}

public boolean isWaitForCompletion() {
return waitForCompletion;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
config.writeTo(out);
out.writeBoolean(deferDefinitionDecompression);
if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
out.writeBoolean(waitForCompletion);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression;
return Objects.equals(config, request.config)
&& deferDefinitionDecompression == request.deferDefinitionDecompression
&& waitForCompletion == request.waitForCompletion;
}

@Override
public int hashCode() {
return Objects.hash(config, deferDefinitionDecompression);
return Objects.hash(config, deferDefinitionDecompression, waitForCompletion);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
Expand Down Expand Up @@ -96,6 +97,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
public static final ParseField LOCATION = new ParseField("location");
public static final ParseField MODEL_PACKAGE = new ParseField("model_package");

public static final TransportVersion VERSION_3RD_PARTY_CONFIG_ADDED = TransportVersion.V_8_0_0;

Expand Down Expand Up @@ -155,6 +157,11 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
: p.namedObject(StrictlyParsedTrainedModelLocation.class, n, null),
LOCATION
);
parser.declareObject(
TrainedModelConfig.Builder::setModelPackageConfig,
(p, c) -> ignoreUnknownFields ? ModelPackageConfig.fromXContentLenient(p) : ModelPackageConfig.fromXContentStrict(p),
MODEL_PACKAGE
);
return parser;
}

Expand All @@ -179,6 +186,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo

private final LazyModelDefinition definition;
private final TrainedModelLocation location;
private final ModelPackageConfig modelPackageConfig;

TrainedModelConfig(
String modelId,
Expand All @@ -196,7 +204,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
String licenseLevel,
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig,
TrainedModelLocation location
TrainedModelLocation location,
ModelPackageConfig modelPackageConfig
) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.modelType = modelType;
Expand All @@ -222,6 +231,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
this.inferenceConfig = inferenceConfig;
this.location = location;
this.modelPackageConfig = modelPackageConfig;
}

private static TrainedModelInput handleDefaultInput(TrainedModelInput input, TrainedModelType modelType) {
Expand Down Expand Up @@ -254,6 +264,15 @@ public TrainedModelConfig(StreamInput in) throws IOException {
this.modelType = null;
this.location = null;
}
if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
modelPackageConfig = in.readOptionalWriteable(ModelPackageConfig::new);
} else {
modelPackageConfig = null;
}
}

public boolean isPackagedModel() {
return modelId.startsWith(".");
}

public String getModelId() {
Expand Down Expand Up @@ -313,6 +332,10 @@ public BytesReference getCompressedDefinitionIfSet() {
return definition.getCompressedDefinitionIfSet();
}

public ModelPackageConfig getModelPackageConfig() {
return modelPackageConfig;
}

public void clearCompressed() {
definition.compressedRepresentation = null;
}
Expand Down Expand Up @@ -397,6 +420,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(modelType);
out.writeOptionalNamedWriteable(location);
}

if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
out.writeOptionalWriteable(modelPackageConfig);
}
}

@Override
Expand All @@ -406,6 +433,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType.toString());
}
if (modelPackageConfig != null) {
builder.field(MODEL_PACKAGE.getPreferredName(), modelPackageConfig);
}

// If the model is to be exported for future import to another cluster, these fields are irrelevant.
if (params.paramAsBoolean(EXCLUDE_GENERATED, false) == false) {
builder.field(CREATED_BY.getPreferredName(), createdBy);
Expand Down Expand Up @@ -468,6 +499,7 @@ public boolean equals(Object o) {
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(modelType, that.modelType)
&& Objects.equals(modelPackageConfig, that.modelPackageConfig)
&& Objects.equals(createdBy, that.createdBy)
&& Objects.equals(version, that.version)
&& Objects.equals(description, that.description)
Expand All @@ -489,6 +521,7 @@ public int hashCode() {
return Objects.hash(
modelId,
modelType,
modelPackageConfig,
createdBy,
version,
createTime,
Expand Down Expand Up @@ -524,6 +557,7 @@ public static class Builder {
private Map<String, String> defaultFieldMap;
private InferenceConfig inferenceConfig;
private TrainedModelLocation location;
private ModelPackageConfig modelPackageConfig;

public Builder() {}

Expand All @@ -544,6 +578,7 @@ public Builder(TrainedModelConfig config) {
this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap);
this.inferenceConfig = config.inferenceConfig;
this.location = config.location;
this.modelPackageConfig = config.modelPackageConfig;
}

public Builder setModelId(String modelId) {
Expand All @@ -569,6 +604,11 @@ public String getModelId() {
return this.modelId;
}

public Builder setModelPackageConfig(ModelPackageConfig modelPackageConfig) {
this.modelPackageConfig = modelPackageConfig;
return this;
}

public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
Expand Down Expand Up @@ -743,34 +783,46 @@ public Builder validate() {
public Builder validate(boolean forCreation) {
// We require a definition to be available here even though it will be stored in a different doc
ActionRequestValidationException validationException = null;
if (definition != null && location != null) {
validationException = addValidationError(
"["
+ DEFINITION.getPreferredName()
+ "] "
+ "and ["
+ LOCATION.getPreferredName()
+ "] are both defined but only one can be used.",
validationException
);
}
if (definition == null && modelType == null) {
validationException = addValidationError(
"[" + MODEL_TYPE.getPreferredName() + "] must be set if " + "[" + DEFINITION.getPreferredName() + "] is not defined.",
validationException
);
}
boolean packagedModel = modelId != null && modelId.startsWith(".");

if (modelId == null) {
validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
}
if (inferenceConfig == null && forCreation) {
validationException = addValidationError(
"[" + INFERENCE_CONFIG.getPreferredName() + "] must not be null.",
validationException
);
}

if (modelId != null && MlStrings.isValidId(modelId) == false) {
if (packagedModel == false) {
if (definition != null && location != null) {
validationException = addValidationError(
"["
+ DEFINITION.getPreferredName()
+ "] "
+ "and ["
+ LOCATION.getPreferredName()
+ "] are both defined but only one can be used.",
validationException
);
}
if (definition == null && modelType == null) {
validationException = addValidationError(
"["
+ MODEL_TYPE.getPreferredName()
+ "] must be set if "
+ "["
+ DEFINITION.getPreferredName()
+ "] is not defined.",
validationException
);
}

if (inferenceConfig == null && forCreation) {
validationException = addValidationError(
"[" + INFERENCE_CONFIG.getPreferredName() + "] must not be null.",
validationException
);
}
}
if (modelId != null && packagedModel
? MlStrings.isValidId(modelId.substring(1)) // packaged models
: MlStrings.isValidId(modelId) == false) {
validationException = addValidationError(
Messages.getMessage(Messages.INVALID_ID, TrainedModelConfig.MODEL_ID.getPreferredName(), modelId),
validationException
Expand Down Expand Up @@ -835,7 +887,41 @@ public Builder validate(boolean forCreation) {
validationException
);
}

// packaged model validation
validationException = checkIllegalSetting(modelPackageConfig, MODEL_PACKAGE.getPreferredName(), validationException);
}
if (validationException != null) {
throw validationException;
}

return this;
}

/**
* Validate that fields defined by the package aren't defined in the request.
*
* To be called by the transport after checking that the package exists.
*/
public Builder validateNoPackageOverrides() {
ActionRequestValidationException validationException = null;
validationException = checkIllegalPackagedModelSetting(description, DESCRIPTION.getPreferredName(), validationException);
validationException = checkIllegalPackagedModelSetting(definition, DEFINITION.getPreferredName(), validationException);
validationException = checkIllegalPackagedModelSetting(modelType, MODEL_TYPE.getPreferredName(), validationException);
validationException = checkIllegalPackagedModelSetting(metadata, METADATA.getPreferredName(), validationException);
validationException = checkIllegalPackagedModelSetting(modelSize, MODEL_SIZE_BYTES.getPreferredName(), validationException);
validationException = checkIllegalPackagedModelSetting(
inferenceConfig,
INFERENCE_CONFIG.getPreferredName(),
validationException
);
if (tags != null && tags.isEmpty() == false) {
validationException = addValidationError(
"illegal to set [tags] at inference model creation for packaged model",
validationException
);
}

if (validationException != null) {
throw validationException;
}
Expand All @@ -854,6 +940,20 @@ private static ActionRequestValidationException checkIllegalSetting(
return validationException;
}

private static ActionRequestValidationException checkIllegalPackagedModelSetting(
Object value,
String setting,
ActionRequestValidationException validationException
) {
if (value != null) {
return addValidationError(
"illegal to set [" + setting + "] at inference model creation for packaged model",
validationException
);
}
return validationException;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
Expand All @@ -871,7 +971,8 @@ public TrainedModelConfig build() {
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
defaultFieldMap,
inferenceConfig,
location
location,
modelPackageConfig
);
}
}
Expand Down

0 comments on commit f8c72ae

Please sign in to comment.