Skip to content

Commit

Permalink
move InputStream to byte-buffer conversion
Browse files Browse the repository at this point in the history
- move it from Booster to XGBoost facade class
  • Loading branch information
honzasterba committed Feb 22, 2021
1 parent 3831cbd commit 594e872
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,6 @@ static Booster loadModel(String modelPath) throws XGBoostError {
return ret;
}

/**
* Load a new Booster model from a file opened as input stream.
* The assumption is the input stream only contains one XGBoost Model.
* This can be used to load existing booster models saved by other xgboost bindings.
*
* @param in The input stream of the file.
* @return The created boosted
* @throws XGBoostError
* @throws IOException
*/
static Booster loadModel(InputStream in) throws XGBoostError, IOException {
int size;
byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
Booster ret = new Booster(new HashMap<>(), new DMatrix[0]);
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
return ret;
}

/**
* Load a new Booster model from a byte array buffer.
* The assumption is the array only contains one XGBoost Model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package ml.dmlc.xgboost4j.java;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.*;
import java.util.*;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -56,9 +53,15 @@ public static Booster loadModel(String modelPath)
* @throws XGBoostError
* @throws IOException
*/
public static Booster loadModel(InputStream in)
throws XGBoostError, IOException {
return Booster.loadModel(in);
public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
int size;
byte[] buf = new byte[1<<20];
ByteArrayOutputStream os = new ByteArrayOutputStream();
while ((size = in.read(buf)) != -1) {
os.write(buf, 0, size);
}
in.close();
return Booster.loadModel(buf);
}

/**
Expand All @@ -70,8 +73,7 @@ public static Booster loadModel(InputStream in)
* @return The create boosted
* @throws XGBoostError
*/
public static Booster loadModel(byte[] buffer)
throws XGBoostError, IOException {
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException {
return Booster.loadModel(buffer);
}

Expand Down

0 comments on commit 594e872

Please sign in to comment.