Skip to content

Commit

Permalink
[7.x][ML] Add status and increased estimate to memory usage (#58588) (#…
Browse files Browse the repository at this point in the history
…58606)

Adds parsing of `status` and `memory_reestimate_bytes`
to data frame analytics `memory_usage`. When the training surpasses
the model memory limit, the status will be set to `hard_limit` and
`memory_reestimate_bytes` can be used to update the job's
limit in order to restart the job.

Backport of #58588
  • Loading branch information
dimitris-athanasiou committed Jun 28, 2020
1 parent 3c81b91 commit 1817b89
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,49 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.time.Instant;
import java.util.Locale;
import java.util.Objects;

public class MemoryUsage implements ToXContentObject {

static final ParseField TIMESTAMP = new ParseField("timestamp");
static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
static final ParseField STATUS = new ParseField("status");
static final ParseField MEMORY_REESTIMATE_BYTES = new ParseField("memory_reestimate_bytes");

public static final ConstructingObjectParser<MemoryUsage, Void> PARSER = new ConstructingObjectParser<>("analytics_memory_usage",
true, a -> new MemoryUsage((Instant) a[0], (long) a[1]));
true, a -> new MemoryUsage((Instant) a[0], (long) a[1], (Status) a[2], (Long) a[3]));

static {
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(),
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
ObjectParser.ValueType.VALUE);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return Status.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, STATUS, ObjectParser.ValueType.STRING);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), MEMORY_REESTIMATE_BYTES);
}

@Nullable
private final Instant timestamp;
private final long peakUsageBytes;
private final Status status;
private final Long memoryReestimateBytes;

public MemoryUsage(@Nullable Instant timestamp, long peakUsageBytes) {
public MemoryUsage(@Nullable Instant timestamp, long peakUsageBytes, Status status, @Nullable Long memoryReestimateBytes) {
this.timestamp = timestamp == null ? null : Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
this.peakUsageBytes = peakUsageBytes;
this.status = status;
this.memoryReestimateBytes = memoryReestimateBytes;
}

@Nullable
Expand All @@ -65,13 +80,25 @@ public long getPeakUsageBytes() {
return peakUsageBytes;
}

public Status getStatus() {
return status;
}

public Long getMemoryReestimateBytes() {
return memoryReestimateBytes;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (timestamp != null) {
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
}
builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
builder.field(STATUS.getPreferredName(), status);
if (memoryReestimateBytes != null) {
builder.field(MEMORY_REESTIMATE_BYTES.getPreferredName(), memoryReestimateBytes);
}
builder.endObject();
return builder;
}
Expand All @@ -83,19 +110,37 @@ public boolean equals(Object o) {

MemoryUsage other = (MemoryUsage) o;
return Objects.equals(timestamp, other.timestamp)
&& peakUsageBytes == other.peakUsageBytes;
&& peakUsageBytes == other.peakUsageBytes
&& Objects.equals(status, other.status)
&& Objects.equals(memoryReestimateBytes, other.memoryReestimateBytes);
}

@Override
public int hashCode() {
return Objects.hash(timestamp, peakUsageBytes);
return Objects.hash(timestamp, peakUsageBytes, status, memoryReestimateBytes);
}

@Override
public String toString() {
return new ToStringBuilder(getClass())
.add(TIMESTAMP.getPreferredName(), timestamp == null ? null : timestamp.getEpochSecond())
.add(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes)
.add(STATUS.getPreferredName(), status)
.add(MEMORY_REESTIMATE_BYTES.getPreferredName(), memoryReestimateBytes)
.toString();
}

public enum Status {
OK,
HARD_LIMIT;

public static Status fromString(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
Expand Down Expand Up @@ -1567,6 +1568,8 @@ public void testGetDataFrameAnalyticsStats() throws Exception {
assertThat(progress.get(2), equalTo(new PhaseProgress("computing_outliers", 0)));
assertThat(progress.get(3), equalTo(new PhaseProgress("writing_results", 0)));
assertThat(stats.getMemoryUsage().getPeakUsageBytes(), equalTo(0L));
assertThat(stats.getMemoryUsage().getStatus(), equalTo(MemoryUsage.Status.OK));
assertThat(stats.getMemoryUsage().getMemoryReestimateBytes(), is(nullValue()));
assertThat(stats.getDataCounts(), equalTo(new DataCounts(0, 0, 0)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ protected MemoryUsage createTestInstance() {
}

public static MemoryUsage createRandom() {
return new MemoryUsage(randomBoolean() ? null : Instant.now(), randomNonNegativeLong());
return new MemoryUsage(
randomBoolean() ? null : Instant.now(),
randomNonNegativeLong(),
randomFrom(MemoryUsage.Status.values()),
randomBoolean() ? null : randomNonNegativeLong()
);
}

@Override
Expand All @@ -48,7 +53,8 @@ protected boolean supportsUnknownFields() {
}

public void testToString_GivenNullTimestamp() {
MemoryUsage memoryUsage = new MemoryUsage(null, 42L);
assertThat(memoryUsage.toString(), equalTo("MemoryUsage[timestamp=null, peak_usage_bytes=42]"));
MemoryUsage memoryUsage = new MemoryUsage(null, 42L, MemoryUsage.Status.OK, null);
assertThat(memoryUsage.toString(), equalTo(
"MemoryUsage[timestamp=null, peak_usage_bytes=42, status=ok, memory_reestimate_bytes=null]"));
}
}
18 changes: 16 additions & 2 deletions docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,25 @@ job is started and memory usage is reported.
.Properties of `memory_usage`
[%collapsible%open]
=====
`peak_usage_bytes`:::
`memory_reestimate_bytes`::::
(long)
This value is present when the `status` is `hard_limit` and it
is a new estimate of how much memory the job needs.

`peak_usage_bytes`::::
(long)
The number of bytes used at the highest peak of memory usage.

`timestamp`:::
`status`::::
(string)
The memory usage status. May have one of the following values:
+
--
* `ok`: usage stayed below the limit.
* `hard_limit`: usage surpassed the configured memory limit.
--

`timestamp`::::
(date)
The timestamp when memory usage was calculated.
=====
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;

import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -14,27 +16,31 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.time.Instant;
import java.util.Locale;
import java.util.Objects;

public class MemoryUsage implements Writeable, ToXContentObject {

public static final String TYPE_VALUE = "analytics_memory_usage";

public static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
public static final ParseField STATUS = new ParseField("status");
public static final ParseField MEMORY_REESTIMATE_BYTES = new ParseField("memory_reestimate_bytes");

public static final ConstructingObjectParser<MemoryUsage, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<MemoryUsage, Void> LENIENT_PARSER = createParser(true);

private static ConstructingObjectParser<MemoryUsage, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<MemoryUsage, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE,
ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2]));
ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2], (Status) a[3], (Long) a[4]));

parser.declareString((bucket, s) -> {}, Fields.TYPE);
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
Expand All @@ -43,6 +49,13 @@ private static ConstructingObjectParser<MemoryUsage, Void> createParser(boolean
Fields.TIMESTAMP,
ObjectParser.ValueType.VALUE);
parser.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return Status.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, STATUS, ObjectParser.ValueType.STRING);
parser.declareLong(ConstructingObjectParser.optionalConstructorArg(), MEMORY_REESTIMATE_BYTES);
return parser;
}

Expand All @@ -52,34 +65,54 @@ private static ConstructingObjectParser<MemoryUsage, Void> createParser(boolean
*/
private final Instant timestamp;
private final long peakUsageBytes;
private final Status status;
@Nullable private final Long memoryReestimateBytes;

/**
* Creates a zero usage object
*/
public MemoryUsage(String jobId) {
this(jobId, null, 0);
this(jobId, null, 0, null, null);
}

public MemoryUsage(String jobId, Instant timestamp, long peakUsageBytes) {
public MemoryUsage(String jobId, Instant timestamp, long peakUsageBytes, @Nullable Status status,
@Nullable Long memoryReestimateBytes) {
this.jobId = Objects.requireNonNull(jobId);
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
// internal representation matches toXContent
this.timestamp = timestamp == null ? null : Instant.ofEpochMilli(
ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
this.peakUsageBytes = peakUsageBytes;
this.status = status == null ? Status.OK : status;
this.memoryReestimateBytes = memoryReestimateBytes;
}

public MemoryUsage(StreamInput in) throws IOException {
jobId = in.readString();
timestamp = in.readOptionalInstant();
peakUsageBytes = in.readVLong();
if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
status = Status.readFromStream(in);
memoryReestimateBytes = in.readOptionalVLong();
} else {
status = Status.OK;
memoryReestimateBytes = null;
}
}

public Status getStatus() {
return status;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(jobId);
out.writeOptionalInstant(timestamp);
out.writeVLong(peakUsageBytes);
if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
status.writeTo(out);
out.writeOptionalVLong(memoryReestimateBytes);
}
}

@Override
Expand All @@ -94,6 +127,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
timestamp.toEpochMilli());
}
builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
builder.field(STATUS.getPreferredName(), status);
if (memoryReestimateBytes != null) {
builder.field(MEMORY_REESTIMATE_BYTES.getPreferredName(), memoryReestimateBytes);
}
builder.endObject();
return builder;
}
Expand All @@ -106,12 +143,14 @@ public boolean equals(Object o) {
MemoryUsage other = (MemoryUsage) o;
return Objects.equals(jobId, other.jobId)
&& Objects.equals(timestamp, other.timestamp)
&& peakUsageBytes == other.peakUsageBytes;
&& peakUsageBytes == other.peakUsageBytes
&& Objects.equals(status, other.status)
&& Objects.equals(memoryReestimateBytes, other.memoryReestimateBytes);
}

@Override
public int hashCode() {
return Objects.hash(jobId, timestamp, peakUsageBytes);
return Objects.hash(jobId, timestamp, peakUsageBytes, status, memoryReestimateBytes);
}

@Override
Expand All @@ -127,4 +166,27 @@ public String documentId(String jobId) {
public static String documentIdPrefix(String jobId) {
return TYPE_VALUE + "_" + jobId + "_";
}

public enum Status implements Writeable {
OK,
HARD_LIMIT;

public static Status fromString(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}

public static Status readFromStream(StreamInput in) throws IOException {
return in.readEnum(Status.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ protected ToXContent.Params getToXContentParams() {
}

public static MemoryUsage createRandom() {
return new MemoryUsage(randomAlphaOfLength(10), Instant.now(), randomNonNegativeLong());
return new MemoryUsage(
randomAlphaOfLength(10),
Instant.now(),
randomNonNegativeLong(),
randomBoolean() ? null : randomFrom(MemoryUsage.Status.values()),
randomBoolean() ? null : randomNonNegativeLong()
);
}

@Override
Expand All @@ -60,6 +66,6 @@ protected MemoryUsage createTestInstance() {
public void testZeroUsage() {
MemoryUsage memoryUsage = new MemoryUsage("zero_usage_job");
String asJson = Strings.toString(memoryUsage);
assertThat(asJson, equalTo("{\"peak_usage_bytes\":0}"));
assertThat(asJson, equalTo("{\"peak_usage_bytes\":0,\"status\":\"ok\"}"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
}
MemoryUsage memoryUsage = result.getMemoryUsage();
if (memoryUsage != null) {
statsHolder.setMemoryUsage(memoryUsage);
statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId);
processMemoryUsage(memoryUsage);
}
OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats();
if (outlierDetectionStats != null) {
Expand Down Expand Up @@ -273,4 +272,9 @@ private void setAndReportFailure(Exception e) {
failure = "error processing results; " + e.getMessage();
auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
}

private void processMemoryUsage(MemoryUsage memoryUsage) {
statsHolder.setMemoryUsage(memoryUsage);
statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId);
}
}

0 comments on commit 1817b89

Please sign in to comment.