-
Notifications
You must be signed in to change notification settings - Fork 24.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
) This improves reporting of trained model size in the response of the stats API. In particular, it removes the `model_size_bytes` from the `deployment_stats` section and replaces it with a top-level `model_size_stats` object that contains: - `model_size_bytes`: the actual model size - `required_native_memory_bytes`: the amount of memory required to load a model In addition, these are now reported for PyTorch models regardless of their deployment state. Backport of #82000
- Loading branch information
1 parent
3c21d01
commit c73d651
Showing
14 changed files
with
342 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
...in/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelSizeStats.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; | ||
|
||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.common.unit.ByteSizeValue; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class TrainedModelSizeStats implements ToXContentObject, Writeable { | ||
|
||
private static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes"); | ||
private static final ParseField REQUIRED_NATIVE_MEMORY_BYTES = new ParseField("required_native_memory_bytes"); | ||
|
||
private final long modelSizeBytes; | ||
private final long requiredNativeMemoryBytes; | ||
|
||
public TrainedModelSizeStats(long modelSizeBytes, long requiredNativeMemoryBytes) { | ||
this.modelSizeBytes = modelSizeBytes; | ||
this.requiredNativeMemoryBytes = requiredNativeMemoryBytes; | ||
} | ||
|
||
public TrainedModelSizeStats(StreamInput in) throws IOException { | ||
modelSizeBytes = in.readLong(); | ||
requiredNativeMemoryBytes = in.readLong(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeLong(modelSizeBytes); | ||
out.writeLong(requiredNativeMemoryBytes); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.humanReadableField(MODEL_SIZE_BYTES.getPreferredName(), "model_size", ByteSizeValue.ofBytes(modelSizeBytes)); | ||
builder.humanReadableField( | ||
REQUIRED_NATIVE_MEMORY_BYTES.getPreferredName(), | ||
"required_native_memory", | ||
ByteSizeValue.ofBytes(requiredNativeMemoryBytes) | ||
); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
TrainedModelSizeStats that = (TrainedModelSizeStats) o; | ||
return modelSizeBytes == that.modelSizeBytes && requiredNativeMemoryBytes == that.requiredNativeMemoryBytes; | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(modelSizeBytes, requiredNativeMemoryBytes); | ||
} | ||
} |
Oops, something went wrong.