Accounting for model size when models are not cached (#59607)
When an inference model is loaded its size is accounted for in the circuit breaker,
and that memory should not be released until there are no more users of the model.
Adds a reference count to the model to track usage.
davidkyle committed Jul 15, 2020
1 parent c61d71b commit 8fcd861
Showing 7 changed files with 348 additions and 89 deletions.
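The core idea is plain reference counting around the circuit-breaker accounting: the model's bytes stay charged to the breaker while at least one user holds a reference, and are handed back when the last reference is released. The standalone sketch below illustrates that pattern; the names RefCountedModel and SimpleBreaker are illustrative stand-ins, not classes from this commit, and the real LocalModel delegates to Elasticsearch's CircuitBreaker instead.

import java.util.concurrent.atomic.AtomicLong;

// Illustrative stand-in for the small subset of CircuitBreaker used here.
interface SimpleBreaker {
    void addEstimateBytesAndMaybeBreak(long bytes, String label);
    void addWithoutBreaking(long bytes);
}

// Reference-counted model: constructed with a count of 1, on the assumption
// that its creator has already charged the model's bytes to the breaker.
class RefCountedModel implements AutoCloseable {
    private final String modelId;
    private final long ramBytesUsed;
    private final SimpleBreaker breaker;
    private final AtomicLong referenceCount = new AtomicLong(1);

    RefCountedModel(String modelId, long ramBytesUsed, SimpleBreaker breaker) {
        this.modelId = modelId;
        this.ramBytesUsed = ramBytesUsed;
        this.breaker = breaker;
    }

    long acquire() {
        long count = referenceCount.incrementAndGet();
        // Going from 0 back to 1 means the bytes were already returned,
        // so they must be charged to the breaker again.
        if (count == 1) {
            breaker.addEstimateBytesAndMaybeBreak(ramBytesUsed, modelId);
        }
        return count;
    }

    long release() {
        long count = referenceCount.decrementAndGet();
        assert count >= 0;
        if (count == 0) {
            // Last user is done; hand the bytes back to the breaker.
            breaker.addWithoutBreaking(-ramBytesUsed);
        }
        return count;
    }

    @Override
    public void close() {
        release();
    }
}

Callers that hold a model only for the duration of a block can rely on try-with-resources; the changed call sites in the diffs below either do that or release the model explicitly in their listener callbacks.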
@@ -65,9 +65,14 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
                         model.infer(stringObjectMap, request.getUpdate(), chainedTask)));
 
                     typedChainTaskExecutor.execute(ActionListener.wrap(
-                        inferenceResultsInterfaces ->
-                            listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()),
-                        listener::onFailure
+                        inferenceResultsInterfaces -> {
+                            model.release();
+                            listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build());
+                        },
+                        e -> {
+                            model.release();
+                            listener.onFailure(e);
+                        }
                     ));
                 },
                 listener::onFailure
@@ -73,10 +73,11 @@ public void run(String modelId) {
         LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId);
         try {
             PlainActionFuture<LocalModel> localModelPlainActionFuture = new PlainActionFuture<>();
-            modelLoadingService.getModelForSearch(modelId, localModelPlainActionFuture);
-            LocalModel localModel = localModelPlainActionFuture.actionGet();
+            modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture);
             TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config);
-            inferTestDocs(localModel, testDocsIterator);
+            try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
+                inferTestDocs(localModel, testDocsIterator);
+            }
         } catch (Exception e) {
             throw ExceptionsHelper.serverError("[{}] failed running inference on model [{}]", e, config.getId(), modelId);
         }
@@ -49,69 +49,76 @@ public InferencePipelineAggregator(String name, Map<String,
     @Override
     public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
 
-        InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
-            (InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
-        List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
-
-        List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
-        for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
-            Map<String, Object> inputFields = new HashMap<>();
-
-            if (bucket.getDocCount() == 0) {
-                // ignore this empty bucket unless the doc count is used
-                if (bucketPathMap.containsKey("_count") == false) {
-                    newBuckets.add(bucket);
-                    continue;
-                }
-            }
-
-            for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
-                String aggName = entry.getKey();
-                String bucketPath = entry.getValue();
-                Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
-
-                if (propertyValue instanceof Number) {
-                    double doubleVal = ((Number) propertyValue).doubleValue();
-                    // NaN or infinite values indicate a missing value or a
-                    // valid result of an invalid calculation. Either way only
-                    // a valid number will do
-                    if (Double.isFinite(doubleVal)) {
-                        inputFields.put(aggName, doubleVal);
-                    }
-                } else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
-                    double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
-                    if (Double.isFinite(doubleVal)) {
-                        inputFields.put(aggName, doubleVal);
-                    }
-                } else if (propertyValue instanceof StringTerms.Bucket) {
-                    StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
-                    inputFields.put(aggName, b.getKeyAsString());
-                } else if (propertyValue instanceof String) {
-                    inputFields.put(aggName, propertyValue);
-                } else if (propertyValue != null) {
-                    // Doubles, String terms or null are valid, any other type is an error
-                    throw invalidAggTypeError(bucketPath, propertyValue);
-                }
-            }
-
-
-            InferenceResults inference;
-            try {
-                inference = model.infer(inputFields, configUpdate);
-            } catch (Exception e) {
-                inference = new WarningInferenceResults(e.getMessage());
-            }
-
-            final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
-                (p) -> (InternalAggregation) p).collect(Collectors.toList());
-
-            InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
-            aggs.add(aggResult);
-            InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
-            newBuckets.add(newBucket);
-        }
-
-        return originalAgg.create(newBuckets);
-    }
+        try {
+            InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
+                (InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
+            List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
+
+            List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
+            for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
+                Map<String, Object> inputFields = new HashMap<>();
+
+                if (bucket.getDocCount() == 0) {
+                    // ignore this empty bucket unless the doc count is used
+                    if (bucketPathMap.containsKey("_count") == false) {
+                        newBuckets.add(bucket);
+                        continue;
+                    }
+                }
+
+                for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
+                    String aggName = entry.getKey();
+                    String bucketPath = entry.getValue();
+                    Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
+
+                    if (propertyValue instanceof Number) {
+                        double doubleVal = ((Number) propertyValue).doubleValue();
+                        // NaN or infinite values indicate a missing value or a
+                        // valid result of an invalid calculation. Either way only
+                        // a valid number will do
+                        if (Double.isFinite(doubleVal)) {
+                            inputFields.put(aggName, doubleVal);
+                        }
+                    } else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
+                        double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
+                        if (Double.isFinite(doubleVal)) {
+                            inputFields.put(aggName, doubleVal);
+                        }
+                    } else if (propertyValue instanceof StringTerms.Bucket) {
+                        StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
+                        inputFields.put(aggName, b.getKeyAsString());
+                    } else if (propertyValue instanceof String) {
+                        inputFields.put(aggName, propertyValue);
+                    } else if (propertyValue != null) {
+                        // Doubles, String terms or null are valid, any other type is an error
+                        throw invalidAggTypeError(bucketPath, propertyValue);
+                    }
+                }
+
+
+                InferenceResults inference;
+                try {
+                    inference = model.infer(inputFields, configUpdate);
+                } catch (Exception e) {
+                    inference = new WarningInferenceResults(e.getMessage());
+                }
+
+                final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
+                    (p) -> (InternalAggregation) p).collect(Collectors.toList());
+
+                InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
+                aggs.add(aggResult);
+                InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
+                newBuckets.add(newBucket);
+            }
+
+            // the model is released at the end of this block.
+            assert model.getReferenceCount() > 0;
+
+            return originalAgg.create(newBuckets);
+        } finally {
+            model.release();
+        }
+    }
 
     public static Object resolveBucketValue(MultiBucketsAggregation agg,
@@ -7,6 +7,7 @@
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.license.License;
+import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -19,16 +20,29 @@
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 
+import java.io.Closeable;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.LongAdder;
 
 import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
 
-public class LocalModel {
+/**
+ * LocalModels implement reference counting for proper accounting in
+ * the {@link CircuitBreaker}. When the model is no longer used {@link #release()}
+ * must be called and if the reference count == 0 then the model's bytes
+ * will be removed from the circuit breaker.
+ *
+ * The class is constructed with an initial reference count of 1 and its
+ * bytes <em>must</em> have been added to the circuit breaker before construction.
+ * New references must call {@link #acquire()} and {@link #release()} as the model
+ * is used.
+ */
+public class LocalModel implements Closeable {
 
     private final InferenceDefinition trainedModelDefinition;
     private final String modelId;
@@ -40,15 +54,18 @@ public class LocalModel {
     private final LongAdder currentInferenceCount;
     private final InferenceConfig inferenceConfig;
    private final License.OperationMode licenseLevel;
+    private final CircuitBreaker trainedModelCircuitBreaker;
+    private final AtomicLong referenceCount;
 
-    public LocalModel(String modelId,
+    LocalModel(String modelId,
               String nodeId,
               InferenceDefinition trainedModelDefinition,
               TrainedModelInput input,
               Map<String, String> defaultFieldMap,
               InferenceConfig modelInferenceConfig,
               License.OperationMode licenseLevel,
-              TrainedModelStatsService trainedModelStatsService) {
+              TrainedModelStatsService trainedModelStatsService,
+              CircuitBreaker trainedModelCircuitBreaker) {
         this.trainedModelDefinition = trainedModelDefinition;
         this.modelId = modelId;
         this.fieldNames = new HashSet<>(input.getFieldNames());
@@ -60,6 +77,8 @@ public LocalModel(String modelId,
         this.currentInferenceCount = new LongAdder();
         this.inferenceConfig = modelInferenceConfig;
         this.licenseLevel = licenseLevel;
+        this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
+        this.referenceCount = new AtomicLong(1);
     }
 
     long ramBytesUsed() {
@@ -177,4 +196,36 @@ public static void mapFieldsIfNecessary(Map<String, Object> fields, Map<String,
             });
         }
     }
+
+    long acquire() {
+        long count = referenceCount.incrementAndGet();
+        // protect against a race where the model could be released to a
+        // count of zero then the model is quickly re-acquired
+        if (count == 1) {
+            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelDefinition.ramBytesUsed(), modelId);
+        }
+        return count;
+    }
+
+    public long getReferenceCount() {
+        return referenceCount.get();
+    }
+
+    public long release() {
+        long count = referenceCount.decrementAndGet();
+        assert count >= 0;
+        if (count == 0) {
+            // no references to this model, it no longer needs to be accounted for
+            trainedModelCircuitBreaker.addWithoutBreaking(-ramBytesUsed());
+        }
+        return referenceCount.get();
+    }
+
+    /**
+     * Convenience method so the class can be used in try-with-resources
+     * constructs to invoke {@link #release()}.
+     */
+    public void close() {
+        release();
+    }
 }
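
As a usage note, the contract in the class javadoc above can be exercised with the illustrative sketch classes from the example near the top of this page (RefCountDemo and "my-model" are likewise hypothetical names, not part of the commit):

public class RefCountDemo {
    public static void main(String[] args) {
        // No-op breaker stand-in; a real breaker would track totals and may throw on the estimate call.
        SimpleBreaker breaker = new SimpleBreaker() {
            @Override public void addEstimateBytesAndMaybeBreak(long bytes, String label) { /* account or throw */ }
            @Override public void addWithoutBreaking(long bytes) { /* account */ }
        };
        // Assume the model's 1024 bytes were already charged to the breaker by its creator.
        try (RefCountedModel model = new RefCountedModel("my-model", 1024L, breaker)) {
            model.acquire();   // a second consumer, e.g. another pipeline, starts using the model
            model.release();   // that consumer finishes
        }   // close() drops the creator's reference; at zero the bytes leave the breaker
    }
}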
