Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor to have separate dataset utils and partition processor #1089

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.core.utils.ClusterUtil
import com.microsoft.ml.spark.io.http.SharedSingleton
import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.{DartModeParams, ExecutionParams, LightGBMParams,
ObjectiveParams, TrainParams}
import com.microsoft.ml.spark.lightgbm.dataset.DatasetUtils
import com.microsoft.ml.spark.lightgbm.params.{DartModeParams, ExecutionParams, LightGBMParams, ObjectiveParams,
TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
Expand Down Expand Up @@ -169,7 +171,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
val featuresSchema = schema.fields(schema.fieldIndex(getFeaturesCol))
val metadata = AttributeGroup.fromStructField(featuresSchema)
if (metadata.attributes.isDefined) {
val slotNamesOpt = TrainUtils.getSlotNames(df.schema,
val slotNamesOpt = DatasetUtils.getSlotNames(df.schema,
columnParams.featuresColumn, metadata.attributes.get.length, trainParams)
val pattern = new Regex("[\",:\\[\\]{}]")
slotNamesOpt.foreach(slotNames => {
Expand Down Expand Up @@ -251,8 +253,8 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
val schema = preprocessedDF.schema
val columnParams = ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol)
validateSlotNames(preprocessedDF, columnParams, trainParams)
val mapPartitionsFunc = TrainUtils.trainLightGBM(batchIndex, networkParams, columnParams, validationData, log,
trainParams, numTasksPerExec, schema)(_)
val mapPartitionsFunc = PartitionProcessor.trainLightGBM(batchIndex, networkParams, columnParams,
validationData, log, trainParams, numTasksPerExec, schema)(_)
val lightGBMBooster =
if (getUseBarrierExecutionMode) {
preprocessedDF.rdd.barrier().mapPartitions(mapPartitionsFunc).reduce((booster1, _) => booster1)
Expand All @@ -278,6 +280,9 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine

/** Gets the training parameters.
*
* @param numTasks The total number of tasks.
* @param categoricalIndexes The indexes of the categorical slots in the features vector.
* @param dataset The training dataset.
* @return train parameters.
*/
protected def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.lightgbm.lightgbmlib
import com.microsoft.ml.spark.lightgbm.TrainUtils.{afterGenerateTrainDataset, afterGenerateValidDataset,
beforeGenerateTrainDataset, beforeGenerateValidDataset, createBooster, getNetworkInfo, getReturnBooster,
networkInit, trainCore}
import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.dataset.{DatasetUtils, LightGBMDataset}
import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger

/** Runs LightGBM training on a single Spark partition: joins the network ring,
  * builds the native datasets, trains, and (on the designated main worker only)
  * returns the serialized booster.
  */
object PartitionProcessor {
  /** Per-partition training entry point, used with RDD mapPartitions.
    *
    * @param batchIndex The batch index for multi-batch training.
    * @param networkParams Parameters describing the driver-coordinated network ring.
    * @param columnParams The label/feature/weight/init-score/group column names.
    * @param validationData Optional broadcast validation rows.
    * @param log Logger for task-level diagnostics.
    * @param trainParams The LightGBM training parameters.
    * @param numTasksPerExec Number of tasks per executor.
    * @param schema Schema of the input rows.
    * @param inputRows The partition's rows.
    * @return An iterator with one trained booster on the main worker, empty elsewhere.
    */
  def trainLightGBM(batchIndex: Int, networkParams: NetworkParams, columnParams: ColumnParams,
                    validationData: Option[Broadcast[Array[Row]]], log: Logger,
                    trainParams: TrainParams, numTasksPerExec: Int, schema: StructType)
                   (inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
    val hasRows = inputRows.hasNext
    // Only partitions that actually hold data participate as workers in the ring
    val isEnabledWorker = hasRows
    // The native library must be loaded before any lightgbmlib call
    LightGBMUtils.initializeNativeLibrary()
    // Discover the worker ring and the local port this task should listen on
    val (nodes, localListenPort) = getNetworkInfo(networkParams, numTasksPerExec, log, isEnabledWorker)
    if (!hasRows) {
      log.warn("LightGBM task encountered empty partition, for best performance ensure no partitions empty")
      Iterator.empty
    } else {
      log.info(s"LightGBM task listening on: $localListenPort")
      // Return booster only from main worker to reduce network communication overhead
      val returnBooster = getReturnBooster(isEnabledWorker, nodes, log, numTasksPerExec, localListenPort)
      try {
        // Join the ring of communication before generating datasets or training
        networkInit(nodes, localListenPort, log, LightGBMConstants.NetworkRetries, LightGBMConstants.InitialDelay)
        translate(batchIndex, columnParams, validationData, log, trainParams, returnBooster, schema, inputRows)
      } finally {
        // Always tear down the native network state for enabled workers
        if (isEnabledWorker) LightGBMUtils.validate(lightgbmlib.LGBM_NetworkFree(), "Finalize network")
      }
    }
  }

  /** Builds the native train (and optional validation) datasets, trains a booster,
    * and frees the native datasets afterwards regardless of outcome.
    *
    * @param returnBooster Whether this worker should emit the trained model.
    * @return An iterator with the trained booster when returnBooster is set, empty otherwise.
    */
  def translate(batchIndex: Int, columnParams: ColumnParams, validationData: Option[Broadcast[Array[Row]]],
                log: Logger, trainParams: TrainParams, returnBooster: Boolean,
                schema: StructType, inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
    var trainData: Option[LightGBMDataset] = None
    var validData: Option[LightGBMDataset] = None
    try {
      beforeGenerateTrainDataset(batchIndex, columnParams, schema, log, trainParams)
      trainData = DatasetUtils.generateDataset(inputRows, columnParams, None, schema, log, trainParams)
      afterGenerateTrainDataset(batchIndex, columnParams, schema, log, trainParams)

      // Validation dataset references the train dataset so bin mappers stay consistent
      validationData.foreach { broadcastRows =>
        beforeGenerateValidDataset(batchIndex, columnParams, schema, log, trainParams)
        validData = DatasetUtils.generateDataset(broadcastRows.value.toIterator, columnParams,
          trainData, schema, log, trainParams)
        afterGenerateValidDataset(batchIndex, columnParams, schema, log, trainParams)
      }

      trainAndEmit(batchIndex, trainParams, trainData.get, validData, returnBooster, log)
    } finally {
      // Free native dataset memory whether or not training succeeded
      trainData.foreach(_.close())
      validData.foreach(_.close())
    }
  }

  /** Creates the booster, runs the core training loop, and frees the native
    * booster handle; only the serialized copy survives this call.
    */
  private def trainAndEmit(batchIndex: Int, trainParams: TrainParams, trainDataset: LightGBMDataset,
                           validDataset: Option[LightGBMDataset], returnBooster: Boolean,
                           log: Logger): Iterator[LightGBMBooster] = {
    var boosterOpt: Option[LightGBMBooster] = None
    try {
      val booster = createBooster(trainParams, trainDataset, validDataset)
      boosterOpt = Some(booster)
      val bestIterResult = trainCore(batchIndex, trainParams, booster, log, validDataset.isDefined)
      if (!returnBooster) {
        Iterator.empty
      } else {
        // Re-hydrate from the serialized model so the returned booster owns no native handle
        val modelBooster = new LightGBMBooster(booster.saveToString())
        // Set best iteration on booster if hit early stopping criteria in trainCore
        bestIterResult.foreach(modelBooster.setBestIteration)
        Iterator.single(modelBooster)
      }
    } finally {
      // Free the native booster handle
      boosterOpt.foreach(_.freeNativeMemory())
    }
  }
}
Loading