Skip to content

Commit

Permalink
Fix issues with merge conflict resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored and joanfontanals committed Jan 30, 2020
1 parent d5d05f8 commit c7f6a3d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 48 deletions.
103 changes: 56 additions & 47 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LightGBMBooster(val model: String) extends Serializable {

lazy val numClasses: Int = getNumClasses()

lazy val numFeatures: Int = getNumFeatures()
lazy val numFeatures: Int = getNumFeatures

lazy val numTotalModel: Int = getNumTotalModel

Expand All @@ -79,7 +79,7 @@ class LightGBMBooster(val model: String) extends Serializable {

@transient
var shapDataLengthLongPtr: SWIGTYPE_p_long_long = _

@transient
var leafIndexDataOutPtr: SWIGTYPE_p_double = _

Expand Down Expand Up @@ -125,27 +125,6 @@ class LightGBMBooster(val model: String) extends Serializable {
lightgbmlib.delete_doubleArray(shapDataOutPtr)
}

/**
  * Computes feature-contribution values for a single sparse row.
  *
  * Calls the native `LGBM_BoosterPredictForCSRSingle` with the `C_API_PREDICT_CONTRIB` prediction kind,
  * then converts the contents of `shapDataOutPtr` to a Scala array via `predToArray`.
  *
  * @param sparseVector the row to explain; its `size` is passed to the native call as the column count.
  * @return the values read back from `shapDataOutPtr`.
  */
protected def shapForCSR(sparseVector: SparseVector): Array[Double] = {
  val numCols = sparseVector.size
  val kind = lightgbmlibConstants.C_API_PREDICT_CONTRIB

  val datasetParams = "max_bin=255 is_pre_partition=True"
  val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
  val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64

  ensureShapDataCreated()

  LightGBMUtils.validate(
    lightgbmlib.LGBM_BoosterPredictForCSRSingle(
      sparseVector.indices, sparseVector.values,
      // Pass the number of stored entries (the length of indices/values). numNonzeros skips
      // explicitly stored zeros, which would make the count shorter than the arrays handed over
      // and silently drop trailing entries; it also rescans the values on every call.
      sparseVector.numActives,
      boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
      kind, -1, datasetParams,
      shapDataLengthLongPtr, shapDataOutPtr), "Booster Predict")

  predToArray(false, shapDataOutPtr, kind)
}

protected def predictScoreForCSR(sparseVector: SparseVector, kind: Int, classification: Boolean): Array[Double] = {
val numCols = sparseVector.size

Expand All @@ -166,27 +145,6 @@ class LightGBMBooster(val model: String) extends Serializable {
predScoreToArray(classification, scoredDataOutPtr, kind)
}

/**
  * Computes feature-contribution values for a single dense row.
  *
  * Invokes the native `LGBM_BoosterPredictForMatSingle` with the `C_API_PREDICT_CONTRIB` prediction
  * kind and converts the contents of `shapDataOutPtr` to a Scala array via `predToArray`.
  *
  * @param row the dense feature values; its length is passed to the native call as the column count.
  * @return the values read back from `shapDataOutPtr`.
  */
protected def shapForMat(row: Array[Double]): Array[Double] = {
  val contribKind = lightgbmlibConstants.C_API_PREDICT_CONTRIB
  val float64Type = lightgbmlibConstants.C_API_DTYPE_FLOAT64
  val rowMajorFlag = 1
  val nativeParams = "max_bin=255"

  // Runs before the native call, which receives shapDataLengthLongPtr and shapDataOutPtr.
  ensureShapDataCreated()

  val returnCode = lightgbmlib.LGBM_BoosterPredictForMatSingle(
    row, boosterPtr, float64Type, row.length, rowMajorFlag, contribKind,
    -1, nativeParams, shapDataLengthLongPtr, shapDataOutPtr)
  LightGBMUtils.validate(returnCode, "Booster Predict")

  predToArray(false, shapDataOutPtr, contribKind)
}

protected def predictScoreForMat(row: Array[Double], kind: Int, classification: Boolean): Array[Double] = {
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64

Expand Down Expand Up @@ -248,6 +206,48 @@ class LightGBMBooster(val model: String) extends Serializable {
predLeafToArray(leafIndexDataOutPtr)
}

/**
  * Computes feature-contribution values for a single sparse row.
  *
  * Calls the native `LGBM_BoosterPredictForCSRSingle` with the `C_API_PREDICT_CONTRIB` prediction kind,
  * then copies the contents of `shapDataOutPtr` into a Scala array via `shapToArray`.
  *
  * @param sparseVector the row to explain; its `size` is passed to the native call as the column count.
  * @return the values read back from `shapDataOutPtr`.
  */
protected def shapForCSR(sparseVector: SparseVector): Array[Double] = {
  val numCols = sparseVector.size
  val kind = lightgbmlibConstants.C_API_PREDICT_CONTRIB

  val datasetParams = "max_bin=255 is_pre_partition=True"
  val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
  val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64

  ensureShapDataCreated()

  LightGBMUtils.validate(
    lightgbmlib.LGBM_BoosterPredictForCSRSingle(
      sparseVector.indices, sparseVector.values,
      // Pass the number of stored entries (the length of indices/values). numNonzeros skips
      // explicitly stored zeros, which would make the count shorter than the arrays handed over
      // and silently drop trailing entries; it also rescans the values on every call.
      sparseVector.numActives,
      boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
      kind, -1, datasetParams,
      shapDataLengthLongPtr, shapDataOutPtr), "Booster Predict")

  shapToArray(shapDataOutPtr)
}

/**
  * Computes feature-contribution values for a single dense row.
  *
  * Invokes the native `LGBM_BoosterPredictForMatSingle` with the `C_API_PREDICT_CONTRIB` prediction
  * kind and copies the contents of `shapDataOutPtr` into a Scala array via `shapToArray`.
  *
  * @param row the dense feature values; its length is passed to the native call as the column count.
  * @return the values read back from `shapDataOutPtr`.
  */
protected def shapForMat(row: Array[Double]): Array[Double] = {
  val contribKind = lightgbmlibConstants.C_API_PREDICT_CONTRIB
  val float64Type = lightgbmlibConstants.C_API_DTYPE_FLOAT64
  val rowMajorFlag = 1
  val nativeParams = "max_bin=255"

  // Runs before the native call, which receives shapDataLengthLongPtr and shapDataOutPtr.
  ensureShapDataCreated()

  val returnCode = lightgbmlib.LGBM_BoosterPredictForMatSingle(
    row, boosterPtr, float64Type, row.length, rowMajorFlag, contribKind,
    -1, nativeParams, shapDataLengthLongPtr, shapDataOutPtr)
  LightGBMUtils.validate(returnCode, "Booster Predict")

  shapToArray(shapDataOutPtr)
}

def saveNativeModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = {
if (filename == null || filename.isEmpty) {
throw new IllegalArgumentException("filename should not be empty or null.")
Expand Down Expand Up @@ -311,14 +311,18 @@ class LightGBMBooster(val model: String) extends Serializable {
lightgbmlib.intp_value(numClassesOut)
}

/**
  * Asks the native booster how many features the model has.
  *
  * If `boosterPtr` is still null, the native library is initialized and the booster is rebuilt from
  * the `model` string before the query is made.
  *
  * NOTE(review): this listing appears to be a rendered diff, so the two `def` lines and the two
  * closing braces are presumably the before/after versions of the same line - confirm against the file.
  *
  * @return the value written by `LGBM_BoosterGetNumFeature`.
  */
def getNumFeatures(): Int = {
def getNumFeatures: Int = {
if (boosterPtr == null) {
LightGBMUtils.initializeNativeLibrary()
boosterPtr = getBoosterPtrFromModelString(model)
}
// NOTE(review): no matching release of this native int holder is visible in this method - confirm
// whether it needs to be freed after the value is read.
val numFeaturesOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterGetNumFeature(boosterPtr, numFeaturesOut),
"Booster NumFeature")
lightgbmlib.intp_value(numFeaturesOut)
}
}

/**
* Retrieve the number of models per each iteration from LightGBM Booster
* @return The number of models per iteration.
Expand Down Expand Up @@ -375,4 +379,9 @@ class LightGBMBooster(val model: String) extends Serializable {
(0 until numTotalModel).map(modelNum =>
lightgbmlib.doubleArray_getitem(leafIndexDataOutPtr, modelNum)).toArray
}

/**
  * Copies the first `numFeatures` doubles out of a native double array.
  *
  * NOTE(review): LightGBM's C API documents feature-contribution output as (num_feature + 1) values
  * per class; only `numFeatures` entries are read here - confirm the trailing value is meant to be
  * dropped and that the buffer is sized accordingly.
  *
  * @param shapDataOutPtr native array to read from.
  * @return a Scala array holding elements 0 until `numFeatures` of the native array.
  */
private def shapToArray(shapDataOutPtr: SWIGTYPE_p_double): Array[Double] =
  Array.tabulate(numFeatures)(idx => lightgbmlib.doubleArray_getitem(shapDataOutPtr, idx))
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class LightGBMRankerModel(override val uid: String, override val model: LightGBM
override def objectsToSave: List[Any] =
List(uid, model, getLabelCol, getFeaturesCol, getPredictionCol)

override def numFeatures: Int = model.getNumFeatures()
override def numFeatures: Int = model.getNumFeatures

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
Expand Down

0 comments on commit c7f6a3d

Please sign in to comment.