Skip to content

Commit

Permalink
perf: improve lightgbm training performance 4x-10x by setting num_threads to be cores-1
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Dec 2, 2021
1 parent 3898ad9 commit 6fb7fc4
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,12 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
*
* @return ExecutionParams object containing parameters related to LightGBM execution.
*/
protected def getExecutionParams: ExecutionParams = {
ExecutionParams(getChunkSize, getMatrixType, getNumThreads, getUseSingleDatasetMode)
protected def getExecutionParams(numTasksPerExec: Int): ExecutionParams = {
  // Pick the thread count that is handed to ExecutionParams:
  //  - outside single-dataset mode, getNumThreads is used unchanged;
  //  - in single-dataset mode, a value present on the numThreads param wins, and
  //    otherwise the count falls back to one less than numTasksPerExec.
  // get(numThreads) is an Option here; presumably it is empty when the param was
  // never explicitly set (Spark ML Params.get semantics) -- TODO confirm.
  // NOTE(review): the fallback evaluates to 0 when numTasksPerExec == 1; confirm
  // that a zero thread count is what the consumer of ExecutionParams expects.
  val threadsForExecution =
    if (!getUseSingleDatasetMode) getNumThreads
    else get(numThreads).getOrElse(numTasksPerExec - 1)

  ExecutionParams(
    getChunkSize,
    getMatrixType,
    threadsForExecution,
    getUseSingleDatasetMode)
}

protected def getColumnParams: ColumnParams = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class LightGBMClassifier(override val uid: String)
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames,
getDelegate, getDartParams, getExecutionParams, getObjectiveParams)
getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LightGBMRanker(override val uid: String)
getVerbosity, categoricalIndexes, getBoostingType, get(lambdaL1), get(lambdaL2), getMaxPosition, getLabelGain,
get(isProvideTrainingMetric), get(metric), getEvalAt, get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, getDartParams,
getExecutionParams, getObjectiveParams)
getExecutionParams(numTasksPerExec), getObjectiveParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class LightGBMRegressor(override val uid: String)
getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate,
getDartParams, getExecutionParams, getObjectiveParams)
getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
Expand Down

0 comments on commit 6fb7fc4

Please sign in to comment.