Skip to content

Commit

Permalink
test out old sparse dataset create method
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Jun 1, 2021
1 parent 9e9ff1a commit 8511d54
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 15 deletions.
21 changes: 21 additions & 0 deletions src/main/scala/com/microsoft/lightgbm/CSRUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.lightgbm

/** Temporary class that accepts int32_t arrays instead of void pointer arguments.
* TODO: Need to generate a new lightGBM jar with utility to convert int array
* to void pointer and then remove this file.
*/
object CSRUtils {
// scalastyle:off parameter.number
def LGBM_DatasetCreateFromCSR(var0: SWIGTYPE_p_int, var1: Int, var2: SWIGTYPE_p_int, var3: SWIGTYPE_p_void,
var4: Int, var5: Int, var6: Int,
var7: Int, var8: String, var9: SWIGTYPE_p_void,
var10: SWIGTYPE_p_p_void): Int = {
lightgbmlibJNI.LGBM_DatasetCreateFromCSR(SWIGTYPE_p_int.getCPtr(var0), var1, SWIGTYPE_p_int.getCPtr(var2),
SWIGTYPE_p_void.getCPtr(var3), var4, var5, var6,
var7, var8, SWIGTYPE_p_void.getCPtr(var9), SWIGTYPE_p_p_void.getCPtr(var10))
}
// scalastyle:on parameter.number
}
76 changes: 61 additions & 15 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,29 @@ object LightGBMUtils {
dataset
}


def newDoubleArray(array: Array[Double]): (SWIGTYPE_p_void, SWIGTYPE_p_double) = {
val data = lightgbmlib.new_doubleArray(array.length)
array.zipWithIndex.foreach {
case (value, index) => lightgbmlib.doubleArray_setitem(data, index, value)
}
(lightgbmlib.double_to_voidp_ptr(data), data)
}

def newIntArray(array: Array[Int]): (SWIGTYPE_p_int, SWIGTYPE_p_int) = {
val data = lightgbmlib.new_intArray(array.length)
array.zipWithIndex.foreach {
case (value, index) => lightgbmlib.intArray_setitem(data, index, value)
}
(lightgbmlib.int_to_int32_t_ptr(data), data)
}

def intToPtr(value: Int): SWIGTYPE_p_long = {
val longPtr = lightgbmlib.new_longp()
lightgbmlib.longp_assign(longPtr, value)
longPtr
}

/** Generates a sparse dataset in CSR format.
*
* @param sparseRows The rows of sparse vector.
Expand All @@ -264,21 +287,44 @@ object LightGBMUtils {
referenceDataset: Option[LightGBMDataset],
featureNamesOpt: Option[Array[String]],
trainParams: TrainParams): LightGBMDataset = {
val numCols = sparseRows(0).size
var values: Option[(SWIGTYPE_p_void, SWIGTYPE_p_double)] = None
var indexes: Option[(SWIGTYPE_p_int, SWIGTYPE_p_int)] = None
var indptrNative: Option[(SWIGTYPE_p_int, SWIGTYPE_p_int)] = None
try {
val valuesArray = sparseRows.flatMap(_.values)
values = Some(newDoubleArray(valuesArray))
val indexesArray = sparseRows.flatMap(_.indices)
indexes = Some(newIntArray(indexesArray))
val indptr = new Array[Int](sparseRows.length + 1)
sparseRows.zipWithIndex.foreach {
case (row, index) => indptr(index + 1) = indptr(index) + row.numNonzeros
}
indptrNative = Some(newIntArray(indptr))
val numCols = sparseRows(0).size

val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
// Generate the dataset for features
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
sparseRows.asInstanceOf[Array[Object]],
sparseRows.length,
numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull,
datasetOutPtr),
"Dataset create")
val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr))
dataset.setFeatureNames(featureNamesOpt, numCols)
dataset
val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
s"bin_construct_sample_cnt=${trainParams.binSampleCount} " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
// Generate the dataset for features
LightGBMUtils.validate(CSRUtils.LGBM_DatasetCreateFromCSR(
indptrNative.get._1, dataInt32bitType,
indexes.get._1, values.get._1, data64bitType,
indptr.length, valuesArray.length,
numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull,
datasetOutPtr),
"Dataset create")
val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr))
dataset.setFeatureNames(featureNamesOpt, numCols)
dataset
} finally {
// Delete the input rows
if (values.isDefined) lightgbmlib.delete_doubleArray(values.get._2)
if (indexes.isDefined) lightgbmlib.delete_intArray(indexes.get._2)
if (indptrNative.isDefined) lightgbmlib.delete_intArray(indptrNative.get._2)
}
}
}

0 comments on commit 8511d54

Please sign in to comment.