Skip to content

Commit

Permalink
fix: improve error message for invalid slot names (#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Aug 12, 2020
1 parent 95c1f8a commit 96f0b77
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.core.utils.ClusterUtil
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.param.shared.{HasFeaturesCol => HasFeaturesColSpark, HasLabelCol => HasLabelColSpark}
import org.apache.spark.ml.{Estimator, Model}
Expand All @@ -14,6 +15,7 @@ import scala.concurrent.Await
import scala.concurrent.duration.{Duration, SECONDS}
import scala.language.existentials
import scala.math.min
import scala.util.matching.Regex

trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[TrainedModel]
with LightGBMParams with HasFeaturesColSpark with HasLabelColSpark {
Expand Down Expand Up @@ -156,6 +158,25 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
categoricalSlotIndexesArr, categoricalSlotNamesArr)
}

/**
  * Rejects feature slot names that contain any of the characters
  * `"`, `,`, `:`, `[`, `]`, `{` or `}`.
  *
  * The check only runs when the features column carries ML attribute metadata;
  * in that case the slot names are obtained from `TrainUtils.getSlotNames` and,
  * if any are returned, each one is tested against the forbidden-character pattern.
  *
  * @param df           dataset whose schema (features column metadata) is inspected
  * @param columnParams supplies the features column name passed to `TrainUtils.getSlotNames`
  * @param trainParams  training parameters forwarded to `TrainUtils.getSlotNames`
  * @throws IllegalArgumentException if at least one slot name matches the pattern;
  *                                  the message lists every offending name, comma-separated
  */
private def validateSlotNames(df: DataFrame, columnParams: ColumnParams, trainParams: TrainParams): Unit = {
  val dfSchema = df.schema
  val featuresField = dfSchema.fields(dfSchema.fieldIndex(getFeaturesCol))
  // Nothing to validate when the features column has no attribute metadata.
  AttributeGroup.fromStructField(featuresField).attributes.foreach { attrs =>
    val resolvedNames = TrainUtils.getSlotNames(df.schema,
      columnParams.featuresColumn, attrs.length, trainParams)
    val forbiddenChars = new Regex("[\",:\\[\\]{}]")
    resolvedNames.foreach { names =>
      // Keep only the names containing at least one forbidden character.
      val offending = names.filter(name => forbiddenChars.findFirstIn(name).isDefined)
      if (offending.nonEmpty) {
        throw new IllegalArgumentException(
          s"Invalid slot names detected in features column: ${offending.mkString(",")}")
      }
    }
  }
}

/**
* Inner train method for LightGBM learners. Calculates the number of workers,
* creates a driver thread, and runs mapPartitions on the dataset.
Expand Down Expand Up @@ -199,6 +220,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
val preprocessedDF = preprocessData(trainingData)
val schema = preprocessedDF.schema
val columnParams = ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol)
validateSlotNames(preprocessedDF, columnParams, trainParams)
val mapPartitionsFunc = TrainUtils.trainLightGBM(batchIndex, networkParams, columnParams, validationData, log,
trainParams, numTasksPerExec, schema)(_)
val lightGBMBooster =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ class VerifyLightGBMRegressor extends Benchmarks
assert(metric < 0.6)
}

test("Verify LightGBM Regressor with bad column names fails early") {
  // Build 22 slot names that each contain the characters ",:[]{} so fitting is expected to be rejected.
  val invalidSlotNames = (0 until 22).map(idx => "Invalid characters \",:[]{} " + idx).toArray
  val regressorWithBadSlots = baseModel.setSlotNames(invalidSlotNames)
  // The whole fit/transform/collect pipeline must raise IllegalArgumentException.
  interceptWithoutLogging[IllegalArgumentException] {
    regressorWithBadSlots.fit(flareDF).transform(flareDF).collect()
  }
}

test("Verify LightGBM Regressor with tweedie distribution") {
val model = baseModel.setObjective("tweedie").setTweedieVariancePower(1.5)

Expand Down

0 comments on commit 96f0b77

Please sign in to comment.