Skip to content

Commit

Permalink
[JVM] Add Iterator loading API
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 5, 2016
1 parent 770b345 commit 86871d4
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 5 deletions.
62 changes: 62 additions & 0 deletions include/xgboost/c_api.h
Expand Up @@ -12,6 +12,9 @@
#endif

// XGBoost C API will include APIs in Rabit C API
XGB_EXTERN_C {
#include <stdio.h>
}
#include <rabit/c_api.h>

#if defined(_MSC_VER) || defined(_WIN32)
Expand All @@ -26,6 +29,51 @@ typedef unsigned long bst_ulong; // NOLINT(*)
typedef void *DMatrixHandle;
/*! \brief handle to Booster */
typedef void *BoosterHandle;
/*! \brief handle to a data iterator */
typedef void *DataIterHandle;
/*! \brief handle to an internal data holder. */
typedef void *DataHolderHandle;

/*!
 * \brief Mini batch used in XGBoost Data Iteration.
 *
 * One batch carries a group of rows of a sparse matrix in CSR layout together
 * with the optional per-row label and weight arrays. The struct only holds
 * pointers; it does not own the memory they point to.
 */
typedef struct {
/*! \brief number of rows in the minibatch */
size_t size;
/*!
 * \brief row pointer to the rows in the data.
 * The JNI producer shipped with this API fills it with size + 1 entries and
 * requires offset[size] to equal the number of entries in index and value.
 */
long* offset; // NOLINT(*)
/*! \brief labels of each instance (size entries); the JNI producer passes NULL when no labels are set */
float* label;
/*! \brief weight of each instance (size entries), can be NULL */
float* weight;
/*! \brief feature index of each stored entry */
int* index;
/*! \brief feature value of each stored entry, parallel to index */
float* value;
} XGBoostBatchCSR;


/*!
 * \brief Callback used to hand one batch of data over to a data holder.
 * \param handle The handle of the data holder that receives the batch.
 * \param batch The data content to be set.
 * \return An int status code. The JNI caller added with this API expects 0
 *  and otherwise reports XGBGetLastError().
 */
XGB_EXTERN_C typedef int XGBCallbackSetData(
DataHolderHandle handle, XGBoostBatchCSR batch);

/*!
 * \brief The data reading callback function.
 *  The iterator will be able to give a subset of the data as a batch.
 *
 *  If there is data, the function will call set_function to set the data.
 *
 * \param data_handle The handle to the data iterator being advanced.
 * \param set_function The function to call with the next batch, if any.
 * \param set_function_handle The handle to be passed to set_function.
 * \return 0 if we are reaching the end and no batch is returned.
 *  The JNI implementation added with this API returns 1 after a batch was set.
 */
XGB_EXTERN_C typedef int XGBCallbackDataIterNext(
DataIterHandle data_handle,
XGBCallbackSetData* set_function,
DataHolderHandle set_function_handle);

/*!
* \brief get string message of the last error
Expand All @@ -50,6 +98,20 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out);

/*!
 * \brief Create a DMatrix from a data iterator.
 * \param data_handle The handle to the data; presumably forwarded to callback
 *  as its data_handle argument (the JNI binding passes the Java iterator here
 *  and reads it back inside its callback).
 * \param callback The callback to get the data, one batch per invocation.
 * \param cache_info Additional information about cache file, can be null.
 * \param out The created DMatrix
 * \return 0 when success, -1 when failure happens.
 */
XGB_DLL int XGDMatrixCreateFromDataIter(
DataIterHandle data_handle,
XGBCallbackDataIterNext* callback,
const char* cache_info,
DMatrixHandle *out);

/*!
* \brief create a matrix content from csr format
* \param indptr pointer to row headers
Expand Down
Expand Up @@ -16,6 +16,7 @@
package ml.dmlc.xgboost4j;

import java.io.IOException;
import java.util.Iterator;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -47,6 +48,33 @@ public static enum SparseType {
CSC;
}

/**
 * Create DMatrix from iterator.
 *
 * @param iter The data iterator of mini batch to provide the data, must not be null.
 * @param cache_info Cache path information, used for external memory setting, can be null.
 * @throws NullPointerException if {@code iter} is null.
 * @throws XGBoostError if the native call fails (raised through JNIErrorHandle.checkCall).
 */
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
  if (iter == null) {
    throw new NullPointerException("iter: null");
  }
  // The native side writes the created DMatrix handle into slot 0 of this array.
  long[] out = new long[1];
  JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
  handle = out[0];
}

/**
* Create DMatrix by loading libsvm file from dataPath
*
* @param dataPath The path to the data.
* @throws XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
if (dataPath == null) {
throw new NullPointerException("dataPath: null");
Expand All @@ -56,6 +84,14 @@ public DMatrix(String dataPath) throws XGBoostError {
handle = out[0];
}

/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
* @param indices The indices of presenting entries.
* @param data The data content.
* @param st Type of sparsity.
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
Expand Down
@@ -0,0 +1,43 @@
package ml.dmlc.xgboost4j;

/**
 * A mini-batch of rows, stored as a sparse matrix in CSR format, that can be
 * turned into (part of) a DMatrix.
 *
 * Most users never need this class; it exists to support the advanced path of
 * building a DMatrix from an Iterator of DataBatch.
 *
 * The native bridge reads the fields below by name, so they must keep their
 * names and array types.
 */
public class DataBatch {
  /** Offset of each row in the sparse matrix. */
  long[] rowOffset;
  /** Weight of each data point, can be null. */
  float[] weight;
  /** Label of each data point, can be null. */
  float[] label;
  /** Feature (column) index of each stored entry. */
  int[] featureIndex;
  /** Value of each non-missing entry. */
  float[] featureValue;

  /**
   * Get number of rows in the data batch.
   *
   * @return one less than the length of the row offset array.
   */
  public int numRows() {
    final int offsetCount = rowOffset.length;
    return offsetCount - 1;
  }

  /**
   * Shallow copy a DataBatch: the returned object shares every array with this one.
   *
   * @return a new batch referring to the same arrays.
   */
  public DataBatch shallowCopy() {
    final DataBatch copy = new DataBatch();
    copy.featureValue = featureValue;
    copy.featureIndex = featureIndex;
    copy.label = label;
    copy.weight = weight;
    copy.rowOffset = rowOffset;
    return copy;
  }
}
Expand Up @@ -15,6 +15,7 @@
*/
package ml.dmlc.xgboost4j;


/**
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
Expand All @@ -26,6 +27,8 @@ class XgboostJNI {

public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);

public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);

public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);

Expand Down
141 changes: 139 additions & 2 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Expand Up @@ -20,13 +20,124 @@
#include <vector>
#include <string>

//helper functions
//set handle
// helper functions
// set handle
/*!
 * \brief Store a native handle into slot 0 of a Java long[] output array.
 * \param jenv JNI environment of the calling thread.
 * \param jhandle Java long[] that receives the handle in its first element.
 * \param handle Native pointer to hand back to Java.
 */
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
  // Use jlong rather than long: SetLongArrayRegion takes a const jlong*, and
  // long is only 32 bits wide on LLP64 platforms, which would truncate the pointer.
  jlong out = reinterpret_cast<jlong>(handle);
  jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}

// JavaVM recorded at library load time; the data-iterator callback uses it
// to obtain a JNIEnv for whichever thread it runs on.
static JavaVM* global_jvm = nullptr;

// Library load hook: remembers the JavaVM and reports the JNI version in use.
jint JNI_OnLoad(JavaVM *vm, void * /*reserved*/) {
  const jint jni_version = JNI_VERSION_1_6;
  global_jvm = vm;
  return jni_version;
}

/*!
 * \brief XGBCallbackDataIterNext implementation backed by a java.util.Iterator
 *  that yields DataBatch objects.
 *
 * Asks the Java iterator for its next batch, exposes the batch arrays as an
 * XGBoostBatchCSR and passes it to set_function. The Java array buffers are
 * held only for the duration of that call and are released on every exit
 * path, including failed checks.
 *
 * \param data_handle jobject referring to the java.util.Iterator.
 * \param set_function Callback that receives the batch.
 * \param set_function_handle Handle forwarded unchanged to set_function.
 * \return 1 when a batch was handed to set_function, 0 when the iterator has
 *  no more elements. Failures are reported through LOG(FATAL); the trailing
 *  -1 is only reached if LOG(FATAL) returns.
 */
XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
    DataIterHandle data_handle,
    XGBCallbackSetData* set_function,
    DataHolderHandle set_function_handle) {
  jobject jiter = static_cast<jobject>(data_handle);
  JNIEnv* jenv = nullptr;
  int jni_status = global_jvm->GetEnv(
      reinterpret_cast<void **>(&jenv), JNI_VERSION_1_6);
  if (jni_status == JNI_EDETACHED) {
    // Called from a thread the JVM does not know yet: attach it, and fail
    // loudly instead of dereferencing an unset jenv when attaching fails.
    CHECK(global_jvm->AttachCurrentThread(
        reinterpret_cast<void **>(&jenv), nullptr) == JNI_OK)
        << "unable to attach the current thread to the JVM";
  } else {
    CHECK(jni_status == JNI_OK);
  }

  // Everything acquired below is tracked here so that both the normal path
  // and the error path can release it.
  jclass iter_class = nullptr;
  jobject batch = nullptr;
  jclass batch_class = nullptr;
  jlongArray joffset = nullptr;
  jfloatArray jlabel = nullptr;
  jfloatArray jweight = nullptr;
  jintArray jindex = nullptr;
  jfloatArray jvalue = nullptr;
  XGBoostBatchCSR cbatch;
  cbatch.size = 0;
  cbatch.offset = nullptr;
  cbatch.label = nullptr;
  cbatch.weight = nullptr;
  cbatch.index = nullptr;
  cbatch.value = nullptr;

  // Releases every array buffer and local reference acquired so far.
  // Safe to call more than once: released items are reset to nullptr.
  auto release_all = [&]() {
    if (cbatch.offset != nullptr) {
      jenv->ReleaseLongArrayElements(joffset, cbatch.offset, 0);
      cbatch.offset = nullptr;
    }
    if (cbatch.label != nullptr) {
      jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0);
      cbatch.label = nullptr;
    }
    if (cbatch.weight != nullptr) {
      jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
      cbatch.weight = nullptr;
    }
    if (cbatch.index != nullptr) {
      jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0);
      cbatch.index = nullptr;
    }
    if (cbatch.value != nullptr) {
      jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
      cbatch.value = nullptr;
    }
    jobject local_refs[] = {joffset, jlabel, jweight, jindex, jvalue,
                            batch, batch_class, iter_class};
    for (jobject ref : local_refs) {
      if (ref != nullptr) {
        jenv->DeleteLocalRef(ref);
      }
    }
    joffset = nullptr;
    jlabel = nullptr;
    jweight = nullptr;
    jindex = nullptr;
    jvalue = nullptr;
    batch = nullptr;
    batch_class = nullptr;
    iter_class = nullptr;
  };

  // Converts a pending Java exception into a native error. The exception is
  // printed and cleared first because most JNI functions must not be called
  // while an exception is pending.
  auto check_no_java_exception = [&](const char* what) {
    if (jenv->ExceptionCheck()) {
      jenv->ExceptionDescribe();
      jenv->ExceptionClear();
      LOG(FATAL) << what << " raised a Java exception";
    }
  };

  try {
    iter_class = jenv->FindClass("java/util/Iterator");
    jmethodID has_next = jenv->GetMethodID(iter_class, "hasNext", "()Z");
    jmethodID next = jenv->GetMethodID(
        iter_class, "next", "()Ljava/lang/Object;");
    int ret_value = 0;
    const bool has_batch =
        jenv->CallBooleanMethod(jiter, has_next) != JNI_FALSE;
    check_no_java_exception("Iterator.hasNext()");
    if (has_batch) {
      batch = jenv->CallObjectMethod(jiter, next);
      check_no_java_exception("Iterator.next()");
      CHECK(batch != nullptr) << "Iterator.next() must not return null";
      batch_class = jenv->GetObjectClass(batch);
      // The DataBatch fields are looked up by name and JNI type signature.
      joffset = static_cast<jlongArray>(jenv->GetObjectField(
          batch, jenv->GetFieldID(batch_class, "rowOffset", "[J")));
      jlabel = static_cast<jfloatArray>(jenv->GetObjectField(
          batch, jenv->GetFieldID(batch_class, "label", "[F")));
      jweight = static_cast<jfloatArray>(jenv->GetObjectField(
          batch, jenv->GetFieldID(batch_class, "weight", "[F")));
      jindex = static_cast<jintArray>(jenv->GetObjectField(
          batch, jenv->GetFieldID(batch_class, "featureIndex", "[I")));
      jvalue = static_cast<jfloatArray>(jenv->GetObjectField(
          batch, jenv->GetFieldID(batch_class, "featureValue", "[F")));
      // label and weight are optional; the CSR arrays are not.
      CHECK(joffset != nullptr) << "batch.rowOffset must not be null";
      CHECK(jindex != nullptr) << "batch.featureIndex must not be null";
      CHECK(jvalue != nullptr) << "batch.featureValue must not be null";
      CHECK(jenv->GetArrayLength(joffset) >= 1)
          << "batch.rowOffset must contain at least one entry";

      cbatch.size = jenv->GetArrayLength(joffset) - 1;
      cbatch.offset = jenv->GetLongArrayElements(joffset, 0);
      CHECK(cbatch.offset != nullptr) << "unable to access batch.rowOffset";
      if (jlabel != nullptr) {
        cbatch.label = jenv->GetFloatArrayElements(jlabel, 0);
        CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast<long>(cbatch.size))  // NOLINT(*)
            << "batch.label.length must equal batch.numRows()";
      }
      if (jweight != nullptr) {
        cbatch.weight = jenv->GetFloatArrayElements(jweight, 0);
        CHECK_EQ(jenv->GetArrayLength(jweight), static_cast<long>(cbatch.size))  // NOLINT(*)
            << "batch.weight.length must equal batch.numRows()";
      }
      // The last row offset is the total number of stored entries.
      long max_elem = cbatch.offset[cbatch.size];  // NOLINT(*)
      cbatch.index = jenv->GetIntArrayElements(jindex, 0);
      cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
      CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
          << "batch.index.length must equal batch.offset.back()";
      CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
          << "batch.value.length must equal batch.offset.back()";
      // cbatch is ready
      CHECK_EQ((*set_function)(set_function_handle, cbatch), 0)
          << XGBGetLastError();
      ret_value = 1;
    }
    release_all();
    // only detach if it is a async call.
    if (jni_status == JNI_EDETACHED) {
      global_jvm->DetachCurrentThread();
    }
    return ret_value;
  } catch (const dmlc::Error& e) {
    // Give the Java buffers back before the error is reported.
    release_all();
    // only detach if it is a async call.
    if (jni_status == JNI_EDETACHED) {
      global_jvm->DetachCurrentThread();
    }
    LOG(FATAL) << e.what();
    return -1;
  }
}

/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) {
jstring jresult = 0;
Expand All @@ -37,6 +148,32 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
return jresult;
}

/*
 * Class:     ml_dmlc_xgboost4j_XgboostJNI
 * Method:    XGDMatrixCreateFromDataIter
 * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
 *
 * Creates a DMatrix from a Java iterator by passing the iterator object as
 * the data handle and XGBoost4jCallbackDataIterNext as the callback. The
 * resulting handle is written to jout[0]; the native status code is returned.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
  (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
  // Start from a null handle so that a failed creation publishes 0 to Java
  // instead of an indeterminate pointer value.
  DMatrixHandle result = nullptr;
  const char* cache_info = nullptr;
  if (jcache_info != nullptr) {
    cache_info = jenv->GetStringUTFChars(jcache_info, 0);
  }
  // NOTE(review): jiter is a local reference. The callback has a code path
  // for threads that are not attached to the JVM; if it can really run on
  // another thread, a global reference is needed here. Confirm against the
  // XGDMatrixCreateFromDataIter implementation.
  int ret = XGDMatrixCreateFromDataIter(
      jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
  if (cache_info) {
    jenv->ReleaseStringUTFChars(jcache_info, cache_info);
  }
  setHandle(jenv, jout, result);
  return ret;
}

/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result;
Expand Down
8 changes: 8 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 86871d4

Please sign in to comment.