[ML] Rationalise trained models error messages (#82054) (#82068)
Use consistent language in the error messages from trained model inference.
# Conflicts:
#	x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
#	x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
davidkyle committed Dec 23, 2021
1 parent 14fd6e3 commit 942ca30
Showing 11 changed files with 48 additions and 40 deletions.
@@ -115,7 +115,7 @@ public final class Messages {
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
-public static final String VOCABULARY_NOT_FOUND = "[{0}] Could not find vocabulary document [{1}] for model ";
+public static final String VOCABULARY_NOT_FOUND = "Could not find vocabulary document [{1}] for trained model [{0}]";
public static final String INFERENCE_CANNOT_DELETE_ML_MANAGED_MODEL =
"Unable to delete model [{0}] as it is required by machine learning";
public static final String MODEL_DEFINITION_TRUNCATED =
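A minimal standalone sketch (not part of the commit) of how the {0}/{1} placeholders in the Messages constants above are resolved. It assumes the MessageFormat-style templating that Messages.getMessage applies to these constants; the demo class name is illustrative.

```java
import java.text.MessageFormat;
import java.util.Locale;

public class VocabularyMessageSketch {

    // Copied from the new Messages constant above
    static final String VOCABULARY_NOT_FOUND = "Could not find vocabulary document [{1}] for trained model [{0}]";

    public static void main(String[] args) {
        // {0} is the model id and {1} the vocabulary document id, so the argument
        // order stays (modelId, docId) even though the message names the document first.
        String msg = new MessageFormat(VOCABULARY_NOT_FOUND, Locale.ROOT).format(new Object[] { "my-model", "vocab-doc-1" });
        System.out.println(msg);
        // prints: Could not find vocabulary document [vocab-doc-1] for trained model [my-model]
    }
}
```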
@@ -488,10 +488,7 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
}
""");
Exception ex = expectThrows(Exception.class, () -> client().performRequest(request));
-assertThat(
-ex.getMessage(),
-containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API.")
-);
+assertThat(ex.getMessage(), containsString("Trained model [not-deployed] is not deployed."));
}

public void testStopUsedDeploymentByIngestProcessor() throws IOException {
@@ -553,7 +550,7 @@ public void testPipelineWithBadProcessor() throws IOException {
assertThat(
response,
allOf(
containsString("inference not possible. Task is configured with [pass_through] but received update of type [ner]"),
containsString("Trained model [deployed] is configured for task [pass_through] but called with task [ner]"),
containsString("error"),
not(containsString("warning"))
)
@@ -596,7 +593,7 @@ public void testPipelineWithBadProcessor() throws IOException {
""";

response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
assertThat(response, containsString("no value could be found for input field [input]"));
assertThat(response, containsString("Input field [input] does not exist in the source document"));
assertThat(response, containsString("status_exception"));
}

@@ -83,14 +83,14 @@ protected void doExecute(
);
return;
}
String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not started";
String message = "Trained model [" + deploymentId + "] is not deployed";
listener.onFailure(ExceptionsHelper.conflictStatusException(message));
}, listener::onFailure));
return;
}
String[] randomRunningNode = allocation.getStartedNodes();
if (randomRunningNode.length == 0) {
String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not yet running on any node";
String message = "Trained model [" + deploymentId + "] is not allocated to any nodes";
listener.onFailure(ExceptionsHelper.conflictStatusException(message));
return;
}
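A hedged sketch of what the two failures above look like to a client. It assumes ExceptionsHelper.conflictStatusException simply wraps the message in an ElasticsearchStatusException carrying RestStatus.CONFLICT (HTTP 409), so "not deployed" and "not allocated to any nodes" read as retryable conflicts rather than client errors; the helper method below is illustrative, not the x-pack implementation.

```java
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

public class DeploymentConflictSketch {

    // Assumed equivalent of ExceptionsHelper.conflictStatusException(message)
    static ElasticsearchStatusException conflict(String message) {
        return new ElasticsearchStatusException(message, RestStatus.CONFLICT);
    }

    public static void main(String[] args) {
        ElasticsearchStatusException e = conflict("Trained model [my-model] is not deployed");
        // A client sees HTTP 409 and can retry once the deployment has started
        System.out.println(e.status() + ": " + e.getMessage());
    }
}
```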
@@ -112,12 +112,14 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
}

private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
logger.debug("[{}] Starting model deployment", task.getModelId());
logger.info("[{}] Starting model deployment", task.getModelId());

ProcessContext processContext = new ProcessContext(task, executorServiceForProcess);

if (processContextByAllocation.putIfAbsent(task.getId(), processContext) != null) {
-finalListener.onFailure(ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId()));
+finalListener.onFailure(
+    ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getModelId())
+);
return;
}

@@ -191,7 +193,7 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
) {
return Vocabulary.createParser(true).apply(parser, null);
} catch (IOException e) {
logger.error(new ParameterizedMessage("failed to parse vocabulary [{}]", hit.getId()), e);
logger.error(new ParameterizedMessage("failed to parse trained model vocabulary [{}]", hit.getId()), e);
throw e;
}
}
@@ -214,7 +216,7 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
logger.info("[{}] Stopping deployment", task.getModelId());
processContext.stopProcess();
} else {
logger.debug("[{}] No process context to stop", task.getModelId());
logger.warn("[{}] No process context to stop", task.getModelId());
}
}

@@ -374,8 +376,8 @@ protected void doRun() throws Exception {
);
processContext.process.get().writeInferenceRequest(request.processInput);
} catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
onFailure(ExceptionsHelper.serverError("error writing to process", e));
logger.error(new ParameterizedMessage("[{}] error writing to inference process", processContext.task.getModelId()), e);
onFailure(ExceptionsHelper.serverError("Error writing to inference process", e));
} catch (Exception e) {
onFailure(e);
}
@@ -389,7 +391,12 @@ private void processResult(
ActionListener<InferenceResults> resultsListener
) {
if (pyTorchResult.isError()) {
-resultsListener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(), RestStatus.INTERNAL_SERVER_ERROR));
+resultsListener.onFailure(
+    new ElasticsearchStatusException(
+        "Error in inference process: [" + pyTorchResult.getError() + "]",
+        RestStatus.INTERNAL_SERVER_ERROR
+    )
+);
return;
}

@@ -428,7 +435,7 @@ class ProcessContext {
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
this.executorService = new ProcessWorkerExecutorService(
threadPool.getThreadContext(),
"pytorch_inference",
"inference process",
task.getParams().getQueueCapacity()
);
}
@@ -460,11 +467,11 @@ synchronized void stopProcess() {

private Consumer<String> onProcessCrash() {
return reason -> {
logger.error("[{}] process crashed due to reason [{}]", task.getModelId(), reason);
logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
resultProcessor.stop();
executorService.shutdown();
processContextByAllocation.remove(task.getId());
task.setFailed("process crashed due to reason [" + reason + "]");
task.setFailed("inference process crashed due to reason [" + reason + "]");
};
}

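The error paths in DeploymentManager all use the same log4j2 idiom. A small sketch of it (class and method names illustrative): ParameterizedMessage carries the "[{}] ..." pattern plus the model id, and the exception is passed separately so the stack trace is preserved in the log.

```java
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;

public class InferenceLoggingSketch {

    private static final Logger logger = LogManager.getLogger(InferenceLoggingSketch.class);

    void onWriteFailure(String modelId, Exception e) {
        // Message + Throwable overload: the pattern is formatted with the model id
        // and the exception is logged with its full stack trace.
        logger.error(new ParameterizedMessage("[{}] error writing to inference process", modelId), e);
    }
}
```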
@@ -111,15 +111,13 @@ protected void onCancelled() {

public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
if (inferenceConfigHolder.get() == null) {
-listener.onFailure(
-ExceptionsHelper.conflictStatusException("[{}] inference not possible against uninitialized model", params.getModelId())
-);
+listener.onFailure(ExceptionsHelper.conflictStatusException("Trained model [{}] is not initialized", params.getModelId()));
return;
}
if (update.isSupported(inferenceConfigHolder.get()) == false) {
listener.onFailure(
new ElasticsearchStatusException(
"[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
"Trained model [{}] is configured for task [{}] but called with task [{}]",
RestStatus.FORBIDDEN,
params.getModelId(),
inferenceConfigHolder.get().getName(),
@@ -332,7 +332,7 @@ private void loadModel(String modelId, Consumer consumer) {
handleLoadFailure(
modelId,
new ElasticsearchStatusException(
"model [{}] with type [{}] is currently not usable in search.",
"Trained model [{}] with type [{}] is currently not usable in search.",
RestStatus.BAD_REQUEST,
modelId,
trainedModelConfig.getModelType()
@@ -342,11 +342,7 @@ private void loadModel(String modelId, Consumer consumer) {
}
handleLoadFailure(
modelId,
-new ElasticsearchStatusException(
-"model [{}] must be deployed to use. Please deploy with the start trained model deployment API.",
-RestStatus.BAD_REQUEST,
-modelId
-)
+new ElasticsearchStatusException("Trained model [{}] is not deployed.", RestStatus.BAD_REQUEST, modelId)
);
return;
}
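Note the two placeholder styles in this commit: the Messages constants use MessageFormat-style {0}/{1}, while ElasticsearchStatusException uses the "{}" convention, filled positionally from the trailing varargs. A small sketch of the latter, under that assumption (class name illustrative):

```java
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

public class NotDeployedSketch {

    static ElasticsearchStatusException notDeployed(String modelId) {
        // The trailing varargs fill the "{}" placeholders in order
        return new ElasticsearchStatusException("Trained model [{}] is not deployed.", RestStatus.BAD_REQUEST, modelId);
    }

    public static void main(String[] args) {
        // prints: Trained model [my-model] is not deployed.
        System.out.println(notDeployed("my-model").getMessage());
    }
}
```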
@@ -8,13 +8,13 @@
package org.elasticsearch.xpack.ml.inference.nlp;

import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.ValidationException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -37,21 +37,26 @@ public class FillMaskProcessor implements NlpTask.Processor {

@Override
public void validateInputs(List<String> inputs) {
+ValidationException ve = new ValidationException();
if (inputs.isEmpty()) {
-throw ExceptionsHelper.badRequestException("input request is empty");
+ve.addValidationError("input request is empty");
}

for (String input : inputs) {
int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN);
if (maskIndex < 0) {
throw ExceptionsHelper.badRequestException("no {} token could be found", BertTokenizer.MASK_TOKEN);
ve.addValidationError("no " + BertTokenizer.MASK_TOKEN + " token could be found in the input");
}

maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length());
if (maskIndex > 0) {
throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", BertTokenizer.MASK_TOKEN);
throw ve.addValidationError("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input");
}
}

+if (ve.validationErrors().isEmpty() == false) {
+throw ve;
+}
}

@Override
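The rewritten validateInputs switches from fail-fast throws to a collect-then-throw pattern: org.elasticsearch.common.ValidationException accumulates every problem and is thrown once, so a request with several bad inputs reports all of them together. A standalone sketch of that pattern, with the "[MASK]" literal standing in for BertTokenizer.MASK_TOKEN:

```java
import java.util.List;

import org.elasticsearch.common.ValidationException;

public class ValidateAllInputsSketch {

    static void validate(List<String> inputs) {
        ValidationException ve = new ValidationException();
        if (inputs.isEmpty()) {
            ve.addValidationError("input request is empty");
        }
        for (String input : inputs) {
            if (input.contains("[MASK]") == false) {
                ve.addValidationError("no [MASK] token could be found in the input");
            }
        }
        // Throw once, after every input has been inspected
        if (ve.validationErrors().isEmpty() == false) {
            throw ve;
        }
    }

    public static void main(String[] args) {
        try {
            validate(List.of("first bad input", "second bad input"));
        } catch (ValidationException e) {
            // Both errors appear in a single "Validation Failed: ..." message
            System.out.println(e.getMessage());
        }
    }
}
```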
@@ -123,12 +123,12 @@ public static String extractInput(TrainedModelInput input, Map<String, Object> d
String inputField = input.getFieldNames().get(0);
Object inputValue = XContentMapValues.extractValue(inputField, doc);
if (inputValue == null) {
throw ExceptionsHelper.badRequestException("no value could be found for input field [{}]", inputField);
throw ExceptionsHelper.badRequestException("Input field [{}] does not exist in the source document", inputField);
}
if (inputValue instanceof String) {
return (String) inputValue;
}
throw ExceptionsHelper.badRequestException("input value [{}] for field [{}] is not a string", inputValue, inputField);
throw ExceptionsHelper.badRequestException("Input value [{}] for field [{}] must be a string", inputValue, inputField);
}

public static class Request {
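A sketch of the extraction logic above with the ML-specific types stripped out. XContentMapValues.extractValue resolves dotted paths against the source map and returns null for a missing field, which is what now produces the "does not exist in the source document" message; the plain IllegalArgumentException here stands in for ExceptionsHelper.badRequestException.

```java
import java.util.Map;

import org.elasticsearch.common.xcontent.support.XContentMapValues;

public class ExtractInputSketch {

    static String extractInput(String inputField, Map<String, Object> doc) {
        Object inputValue = XContentMapValues.extractValue(inputField, doc);
        if (inputValue == null) {
            throw new IllegalArgumentException("Input field [" + inputField + "] does not exist in the source document");
        }
        if (inputValue instanceof String) {
            return (String) inputValue;
        }
        throw new IllegalArgumentException("Input value [" + inputValue + "] for field [" + inputField + "] must be a string");
    }

    public static void main(String[] args) {
        // Dotted paths walk nested maps: "outer.inner" resolves to "some text"
        Map<String, Object> doc = Map.of("outer", Map.of("inner", "some text"));
        System.out.println(extractInput("outer.inner", doc));
    }
}
```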
@@ -97,7 +97,11 @@ public synchronized void execute(Runnable command) {

boolean added = queue.offer(contextHolder.preserveContext(command));
if (added == false) {
throw new ElasticsearchStatusException("Unable to execute on [{}] as queue is full", RestStatus.TOO_MANY_REQUESTS, processName);
throw new ElasticsearchStatusException(
processName + " queue is full. Unable to execute command",
RestStatus.TOO_MANY_REQUESTS,
processName
);
}
}

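The queue-full rejection above comes from the worker executor's bounded queue. A minimal sketch of the behaviour, assuming the executor offers work to a fixed-capacity queue and rejects instead of blocking; the capacity of 1024 here is illustrative, as the real value comes from task.getParams().getQueueCapacity().

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

public class BoundedQueueSketch {

    private final String processName = "inference process";
    private final BlockingQueue<Runnable> queue = new ArrayBlockingQueue<>(1024); // illustrative capacity

    public void execute(Runnable command) {
        // offer() returns false instead of blocking when the queue is at capacity,
        // so callers get an immediate 429 rather than a stalled request
        if (queue.offer(command) == false) {
            throw new ElasticsearchStatusException(
                processName + " queue is full. Unable to execute command",
                RestStatus.TOO_MANY_REQUESTS
            );
        }
    }
}
```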
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.nlp;

import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.ValidationException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -87,7 +88,7 @@ public void testValidate_GivenMissingMaskToken() {
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);

-ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
+ValidationException e = expectThrows(ValidationException.class, () -> processor.validateInputs(input));
assertThat(e.getMessage(), containsString("no [MASK] token could be found"));
}

@@ -97,7 +98,7 @@ public void testProcessResults_GivenMultipleMaskTokens() {
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);

-ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
+ValidationException e = expectThrows(ValidationException.class, () -> processor.validateInputs(input));
assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input"));
}
}
@@ -42,7 +42,7 @@ public void testExtractInput_GivenFieldIsNotPresent() {
);

assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
assertThat(e.getMessage(), equalTo("no value could be found for input field [" + fieldName + "]"));
assertThat(e.getMessage(), equalTo("Input field [" + fieldName + "] does not exist in the source document"));
}

public void testExtractInput_GivenFieldIsNotString() {
@@ -57,6 +57,6 @@ public void testExtractInput_GivenFieldIsNotString() {
);

assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
assertThat(e.getMessage(), equalTo("input value [42] for field [" + fieldName + "] is not a string"));
assertThat(e.getMessage(), equalTo("Input value [42] for field [" + fieldName + "] must be a string"));
}
}
