Skip to content

Commit

Permalink
added javadoc and renamed validateParameters function
Browse files Browse the repository at this point in the history
  • Loading branch information
maxhniebergall committed Jun 18, 2024
1 parent ddd8428 commit e57abc1
Showing 1 changed file with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.BaseRestHandler;
Expand Down Expand Up @@ -86,7 +87,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
}

if (restRequest.hasParam(TIMEOUT.getPreferredName())) {
TimeValue openTimeout = sameParamInQueryAndBody(
TimeValue openTimeout = validateParameters(
request.getTimeout(),
restRequest.paramAsTime(TIMEOUT.getPreferredName(), StartTrainedModelDeploymentAction.DEFAULT_TIMEOUT),
StartTrainedModelDeploymentAction.DEFAULT_TIMEOUT
Expand All @@ -95,7 +96,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
}

request.setWaitForState(
sameParamInQueryAndBody(
validateParameters(
request.getWaitForState(),
AllocationStatus.State.fromString(
restRequest.param(WAIT_FOR.getPreferredName(), StartTrainedModelDeploymentAction.DEFAULT_WAITFOR_STATE.toString())
Expand All @@ -109,7 +110,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
NUMBER_OF_ALLOCATIONS.getPreferredName(),
RestApiVersion.V_8,
restRequest,
(r, s) -> sameParamInQueryAndBody(
(r, s) -> validateParameters(
request.getNumberOfAllocations(),
r.paramAsInt(s, StartTrainedModelDeploymentAction.DEFAULT_NUM_ALLOCATIONS),
StartTrainedModelDeploymentAction.DEFAULT_NUM_ALLOCATIONS
Expand All @@ -121,15 +122,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
THREADS_PER_ALLOCATION.getPreferredName(),
RestApiVersion.V_8,
restRequest,
(r, s) -> sameParamInQueryAndBody(
(r, s) -> validateParameters(
request.getThreadsPerAllocation(),
r.paramAsInt(s, StartTrainedModelDeploymentAction.DEFAULT_NUM_THREADS),
StartTrainedModelDeploymentAction.DEFAULT_NUM_THREADS
),
request::setThreadsPerAllocation
);
request.setQueueCapacity(
sameParamInQueryAndBody(
validateParameters(
request.getQueueCapacity(),
restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), StartTrainedModelDeploymentAction.DEFAULT_QUEUE_CAPACITY),
StartTrainedModelDeploymentAction.DEFAULT_QUEUE_CAPACITY
Expand All @@ -138,7 +139,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient

if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) {
request.setCacheSize(
sameParamInQueryAndBody(
validateParameters(
request.getCacheSize(),
ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName()),
null
Expand All @@ -149,7 +150,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
}

request.setPriority(
sameParamInQueryAndBody(
validateParameters(
request.getPriority().toString(),
restRequest.param(StartTrainedModelDeploymentAction.TaskParams.PRIORITY.getPreferredName()),
StartTrainedModelDeploymentAction.DEFAULT_PRIORITY.toString()
Expand All @@ -159,7 +160,17 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
}

private static <T> T sameParamInQueryAndBody(T bodyParam, T queryParam, T paramDefault) {
/**
* This function validates that the body and query parameters don't conflict, and returns the value that should be used.
* When using this function, the body parameter should already have been set to the default value in
* {@link StartTrainedModelDeploymentAction}, or, set to a different value from the rest request.
*
* @param paramDefault (from {@link StartTrainedModelDeploymentAction})
* @return the parameter to use
* @throws ElasticsearchStatusException if the parameters don't match
*/
private static <T> T validateParameters(@Nullable T bodyParam, @Nullable T queryParam, @Nullable T paramDefault)
throws ElasticsearchStatusException {
if (Objects.equals(bodyParam, paramDefault) && queryParam != null) {
// the body param is the same as the default for this value. We cannot tell if this was set intentionally, or if it was just the
// default, thus we will assume it was the default
Expand Down

0 comments on commit e57abc1

Please sign in to comment.