Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] adding new PUT trained model vocabulary endpoint #77387

Merged
merged 6 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/ml/df-analytics/apis/index.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
include::put-trained-models-aliases.asciidoc[leveloffset=+2]
include::put-trained-models.asciidoc[leveloffset=+2]
include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
//UPDATE
include::update-dfanalytics.asciidoc[leveloffset=+2]
//DELETE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ You can use the following APIs to perform {infer} operations:

* <<put-trained-models>>
* <<put-trained-model-definition-part>>
* <<put-trained-model-vocabulary>>
* <<put-trained-models-aliases>>
* <<delete-trained-models>>
* <<delete-trained-models-aliases>>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
[role="xpack"]
[testenv="basic"]
[[put-trained-model-vocabulary]]
= Create trained model vocabulary API
[subs="attributes"]
++++
<titleabbrev>Create trained model vocabulary</titleabbrev>
++++

Creates a trained model vocabulary.
This API is supported only for natural language processing (NLP) models.

experimental::[]
benwtrent marked this conversation as resolved.
Show resolved Hide resolved

[[ml-put-trained-model-vocabulary-request]]
== {api-request-title}

`PUT _ml/trained_models/<model_id>/vocabulary`


[[ml-put-trained-model-vocabulary-prereq]]
== {api-prereq-title}

Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.


[[ml-put-trained-model-vocabulary-path-params]]
== {api-path-parms-title}

`<model_id>`::
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

[[ml-put-trained-model-vocabulary-request-body]]
== {api-request-body-title}

`vocabulary`::
(Required, array of strings)
The model vocabulary. Must not be empty.

////
[[ml-put-trained-model-vocabulary-example]]
== {api-examples-title}
////
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"ml.put_trained_model_vocabulary":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html",
"description":"Creates a trained model vocabulary"
},
"stability":"experimental",
"visibility":"public",
"headers":{
"accept": [ "application/json"],
"content_type": ["application/json"]
},
"url":{
"paths":[
{
"path":"/_ml/trained_models/{model_id}/vocabulary",
"methods":[
"PUT"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained model for this vocabulary"
}
}
}
]
},
"body":{
"description":"The trained model vocabulary",
"required":true
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.action.ValidateActions.addValidationError;

/**
 * Action type registered under {@link #NAME} whose response is an {@link AcknowledgedResponse}.
 * The nested {@link Request} carries a model id together with the vocabulary to store for it.
 */
public class PutTrainedModelVocabularyAction extends ActionType<AcknowledgedResponse> {

    public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction();
    public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put";

    private PutTrainedModelVocabularyAction() {
        super(NAME, AcknowledgedResponse::readFrom);
    }

    /**
     * Request holding the target model id and the list of vocabulary entries.
     */
    public static class Request extends AcknowledgedRequest<Request> {

        public static final ParseField VOCABULARY = new ParseField("vocabulary");

        private static final ObjectParser<Builder, Void> PARSER = createParser();

        /** Builds the body parser; the only declared field is the {@code vocabulary} string array. */
        private static ObjectParser<Builder, Void> createParser() {
            ObjectParser<Builder, Void> objectParser = new ObjectParser<>("put_trained_model_vocabulary", Builder::new);
            objectParser.declareStringArray(Builder::setVocabulary, VOCABULARY);
            return objectParser;
        }

        /**
         * Parses a request body and combines it with the model id.
         *
         * @param modelId the model id to attach to the parsed body
         * @param parser  parser positioned on the request body
         * @return the request built from the parsed body and {@code modelId}
         */
        public static Request parseRequest(String modelId, XContentParser parser) {
            Builder parsed = PARSER.apply(parser, null);
            return parsed.build(modelId);
        }

        private final String modelId;
        private final List<String> vocabulary;

        /**
         * Both arguments are passed through {@code ExceptionsHelper.requireNonNull}.
         *
         * @param modelId    the model id
         * @param vocabulary the vocabulary entries; emptiness is checked later by {@link #validate()}
         */
        public Request(String modelId, List<String> vocabulary) {
            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
            this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
        }

        /**
         * Reads the request from the wire: superclass state, then model id, then vocabulary.
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.modelId = in.readString();
            this.vocabulary = in.readStringList();
        }

        /**
         * @return a validation exception when the vocabulary is empty, otherwise {@code null}
         */
        @Override
        public ActionRequestValidationException validate() {
            if (vocabulary.isEmpty() == false) {
                return null;
            }
            return addValidationError("[vocabulary] must not be empty", null);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null) {
                return false;
            }
            if (getClass() != o.getClass()) {
                return false;
            }
            Request that = (Request) o;
            if (Objects.equals(modelId, that.modelId) == false) {
                return false;
            }
            return Objects.equals(vocabulary, that.vocabulary);
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, vocabulary);
        }

        /**
         * Writes the request in the same order the stream constructor reads it.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(modelId);
            out.writeStringCollection(vocabulary);
        }

        public String getModelId() {
            return modelId;
        }

        public List<String> getVocabulary() {
            return vocabulary;
        }

        /**
         * Mutable holder used as the parse target; the model id is supplied at {@link #build(String)} time.
         */
        public static class Builder {
            private List<String> vocabulary;

            public Builder setVocabulary(List<String> vocabulary) {
                this.vocabulary = vocabulary;
                return this;
            }

            /**
             * @param modelId the model id for the resulting request
             * @return a new request from {@code modelId} and the vocabulary set so far
             */
            public Request build(String modelId) {
                return new Request(modelId, vocabulary);
            }
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class FillMaskConfig implements NlpConfig {

Expand All @@ -38,7 +40,19 @@ public static FillMaskConfig fromXContentLenient(XContentParser parser) {
private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ the caller can't set this but the code creates the config so it is visible

throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
Expand All @@ -49,8 +63,9 @@ private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boole
private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;

public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
}

Expand All @@ -62,7 +77,7 @@ public FillMaskConfig(StreamInput in) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class NerConfig implements NlpConfig {

Expand All @@ -41,7 +43,19 @@ public static NerConfig fromXContentLenient(XContentParser parser) {
private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
Expand All @@ -54,10 +68,11 @@ private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ig
private final Tokenization tokenization;
private final List<String> classificationLabels;

public NerConfig(VocabularyConfig vocabularyConfig,
public NerConfig(@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization,
@Nullable List<String> classificationLabels) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
}
Expand All @@ -78,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (classificationLabels.isEmpty() == false) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class PassThroughConfig implements NlpConfig {

Expand All @@ -38,7 +40,19 @@ public static PassThroughConfig fromXContentLenient(XContentParser parser) {
private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
Expand All @@ -49,8 +63,9 @@ private static ConstructingObjectParser<PassThroughConfig, Void> createParser(bo
private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;

public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
}

Expand All @@ -62,7 +77,7 @@ public PassThroughConfig(StreamInput in) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject();
return builder;
Expand Down
Loading