Skip to content

Commit

Permalink
[api] Adds load model from Inpustream to public API
Browse files Browse the repository at this point in the history
Change-Id: I2106712d175ea1c6ec31d9561f2809473df2fbc9
  • Loading branch information
frankfliu committed Dec 3, 2021
1 parent be5289d commit cbaef03
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 10 deletions.
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Expand Up @@ -116,6 +116,13 @@ public DataType getDataType() {
return dataType;
}

/** {@inheritDoc} */
@Override
public void load(InputStream is, Map<String, ?> options)
throws IOException, MalformedModelException {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
21 changes: 21 additions & 0 deletions api/src/main/java/ai/djl/Model.java
Expand Up @@ -133,6 +133,27 @@ default void load(Path modelPath, String prefix) throws IOException, MalformedMo
void load(Path modelPath, String prefix, Map<String, ?> options)
throws IOException, MalformedModelException;

/**
* Loads the model from the {@code InputStream}.
*
* @param is the {@code InputStream} to load the model from
* @throws IOException when IO operation fails in loading a resource
* @throws MalformedModelException if model file is corrupted
*/
default void load(InputStream is) throws IOException, MalformedModelException {
load(is, null);
}

/**
* Loads the model from the {@code InputStream} with the options provided.
*
* @param is the {@code InputStream} to load the model from
* @param options engine specific load model options, see documentation for each engine
* @throws IOException when IO operation fails in loading a resource
* @throws MalformedModelException if model file is corrupted
*/
void load(InputStream is, Map<String, ?> options) throws IOException, MalformedModelException;

/**
* Saves the model to the specified {@code modelPath} with the name provided.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/repository/zoo/ZooModel.java
Expand Up @@ -59,6 +59,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) {
throw new IllegalArgumentException("ZooModel should not be re-loaded.");
}

/** {@inheritDoc} */
@Override
public void load(InputStream modelStream, Map<String, ?> options) throws IOException {
throw new IllegalArgumentException("ZooModel should not be re-loaded.");
}

/**
* Returns the wrapped model.
*
Expand Down
Expand Up @@ -108,16 +108,14 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

/**
* Load PyTorch model from {@link InputStream}.
*
* <p>Currently, only TorchScript file are supported
*
* @param modelStream the stream of the model file
* @throws IOException model loading error
*/
public void load(InputStream modelStream) throws IOException {
load(modelStream, true);
/** {@inheritDoc} */
@Override
public void load(InputStream modelStream, Map<String, ?> options) throws IOException {
boolean mapLocation = true;
if (options != null && options.containsKey("mapLocation")) {
mapLocation = Boolean.parseBoolean(options.get("mapLocation").toString());
}
load(modelStream, mapLocation);
}

/**
Expand All @@ -128,6 +126,8 @@ public void load(InputStream modelStream) throws IOException {
* @throws IOException model loading error
*/
public void load(InputStream modelStream, boolean mapLocation) throws IOException {
modelDir = Files.createTempDirectory("pt-model");
modelDir.toFile().deleteOnExit();
block = JniUtils.loadModule((PtNDManager) manager, modelStream, mapLocation, false);
}

Expand Down
Expand Up @@ -93,6 +93,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
properties.put("model-type", fta.getModelType());
}

/** {@inheritDoc} */
@Override
public void load(InputStream modelStream, Map<String, ?> options) {
throw new UnsupportedOperationException("Not supported.");
}

private Path findModelFile(String prefix) {
if (Files.isRegularFile(modelDir)) {
Path file = modelDir;
Expand Down

0 comments on commit cbaef03

Please sign in to comment.