Skip to content

Commit

Permalink
[7.x] [ML] calculate cache misses for inference and return in stats (#…
Browse files Browse the repository at this point in the history
…58252) (#58363)

When a local model is constructed, the cache hit miss count is incremented.

When a user calls _stats, we will include the sum cache hit miss count across ALL nodes. This statistic is important to in comparing against the inference_count. If the cache hit miss count is near the inference_count it indicates that the cache is overburdened, or inappropriately configured.
  • Loading branch information
benwtrent committed Jun 19, 2020
1 parent d8dc638 commit bf8641a
Show file tree
Hide file tree
Showing 14 changed files with 472 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
Expand All @@ -38,32 +39,36 @@ public class TrainedModelStats implements ToXContentObject {
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
public static final ParseField INGEST_STATS = new ParseField("ingest");
public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");

private final String modelId;
private final Map<String, Object> ingestStats;
private final int pipelineCount;
private final InferenceStats inferenceStats;

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
new ConstructingObjectParser<>(
"trained_model_stats",
true,
args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2], (InferenceStats) args[3]));

static {
PARSER.declareString(constructorArg(), MODEL_ID);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
PARSER.declareObject(optionalConstructorArg(), InferenceStats.PARSER, INFERENCE_STATS);
}

public static TrainedModelStats fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount, InferenceStats inferenceStats) {
this.modelId = modelId;
this.ingestStats = ingestStats;
this.pipelineCount = pipelineCount;
this.inferenceStats = inferenceStats;
}

/**
Expand All @@ -89,6 +94,13 @@ public int getPipelineCount() {
return pipelineCount;
}

/**
* Inference statistics
*/
public InferenceStats getInferenceStats() {
return inferenceStats;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -97,13 +109,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (ingestStats != null) {
builder.field(INGEST_STATS.getPreferredName(), ingestStats);
}
if (inferenceStats != null) {
builder.field(INFERENCE_STATS.getPreferredName(), inferenceStats);
}
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount);
return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
}

@Override
Expand All @@ -117,7 +132,8 @@ public boolean equals(Object obj) {
TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount);
&& Objects.equals(this.pipelineCount, other.pipelineCount)
&& Objects.equals(this.inferenceStats, other.inferenceStats);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.time.Instant;
import java.util.Objects;

public class InferenceStats implements ToXContentObject {

public static final String NAME = "inference_stats";
public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
public static final ParseField TIMESTAMP = new ParseField("timestamp");

public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (Instant)a[4])
);
static {
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
PARSER.declareField(ConstructingObjectParser.constructorArg(),
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
ObjectParser.ValueType.VALUE);
}

private final long missingAllFieldsCount;
private final long inferenceCount;
private final long failureCount;
private final long cacheMissCount;
private final Instant timeStamp;

private InferenceStats(Long missingAllFieldsCount,
Long inferenceCount,
Long failureCount,
Long cacheMissCount,
Instant instant) {
this(unboxOrZero(missingAllFieldsCount),
unboxOrZero(inferenceCount),
unboxOrZero(failureCount),
unboxOrZero(cacheMissCount),
instant);
}

public InferenceStats(long missingAllFieldsCount,
long inferenceCount,
long failureCount,
long cacheMissCount,
Instant timeStamp) {
this.missingAllFieldsCount = missingAllFieldsCount;
this.inferenceCount = inferenceCount;
this.failureCount = failureCount;
this.cacheMissCount = cacheMissCount;
this.timeStamp = timeStamp == null ?
Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
Instant.ofEpochMilli(timeStamp.toEpochMilli());
}

/**
* How many times this model attempted to infer with all its fields missing
*/
public long getMissingAllFieldsCount() {
return missingAllFieldsCount;
}

/**
* How many inference calls were made against this model
*/
public long getInferenceCount() {
return inferenceCount;
}

/**
* How many inference failures occurred.
*/
public long getFailureCount() {
return failureCount;
}

/**
* How many cache misses occurred when inferring this model
*/
public long getCacheMissCount() {
return cacheMissCount;
}

/**
* The timestamp of these statistics.
*/
public Instant getTimeStamp() {
return timeStamp;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceStats that = (InferenceStats) o;
return missingAllFieldsCount == that.missingAllFieldsCount
&& inferenceCount == that.inferenceCount
&& failureCount == that.failureCount
&& cacheMissCount == that.cacheMissCount
&& Objects.equals(timeStamp, that.timeStamp);
}

@Override
public int hashCode() {
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, timeStamp);
}

@Override
public String toString() {
return "InferenceStats{" +
"missingAllFieldsCount=" + missingAllFieldsCount +
", inferenceCount=" + inferenceCount +
", failureCount=" + failureCount +
", cacheMissCount=" + cacheMissCount +
", timeStamp=" + timeStamp +
'}';
}

private static long unboxOrZero(@Nullable Long value) {
return value == null ? 0L : value;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStatsTests;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand Down Expand Up @@ -58,7 +59,8 @@ protected TrainedModelStats createTestInstance() {
return new TrainedModelStats(
randomAlphaOfLength(10),
randomBoolean() ? null : randomIngestStats(),
randomInt());
randomInt(),
randomBoolean() ? null : InferenceStatsTests.randomInstance());
}

private Map<String, Object> randomIngestStats() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.time.Instant;

public class InferenceStatsTests extends AbstractXContentTestCase<InferenceStats> {

public static InferenceStats randomInstance() {
return new InferenceStats(randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
Instant.now()
);
}

@Override
protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
return InferenceStats.PARSER.apply(parser, null);
}

@Override
protected boolean supportsUnknownFields() {
return true;
}

@Override
protected InferenceStats createTestInstance() {
return randomInstance();
}

}

0 comments on commit bf8641a

Please sign in to comment.