
ClassCastException with LgbmRanker train #823

@arijeetm1

Describe the bug
Training fails with java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double.

To Reproduce

from mmlspark.lightgbm import LightGBMRanker

model = LightGBMRanker(
                boostingType = 'gbdt',
                objective = 'lambdarank',
                maxPosition=50,
                isProvideTrainingMetric=True,
                maxBin = 255,
                evalAt=[10],
                numIterations = 500,
                learningRate = 0.3,
                numLeaves = 127,
                earlyStoppingRound = 20,
                #parallelism = 'serial',
                #num_threads = 8
                featureFraction = 0.5,
                baggingFreq = 1,
                baggingFraction = 0.8,
                #min_data_in_leaf = 20 
                minSumHessianInLeaf = 0.001,
                #is_enable_sparse = True,
                #use_two_round_loading = True,
                #is_save_binary_file = False,
                groupCol='label',
                labelGain=[0.0,1.0,3.0,7.0,15.0],
                categoricalSlotIndexes=[4,5,6,7,8,9,10,11,12,13,14,15,16]
).fit(data)

data.schema
StructType(List(StructField(label,DoubleType,true),StructField(features,VectorUDT,true)))
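
The schema above has only label and features, so groupCol='label' makes the label column double as the group (query) column. A pre-flight check could flag that before calling fit. This is a minimal plain-Python sketch, not an MMLSpark API: `check_ranker_columns` and the `query_id` column are hypothetical, the schema is hand-copied as a dict of Spark type names, and it assumes (based on the query-plan observation in this issue) that the group column should be a separate integer or string query id.

```python
# Hypothetical pre-flight check (not part of MMLSpark). It works on a hand-written
# mapping of column name -> Spark type name, copied from data.schema.
from typing import Dict, List

# Assumption: group ids are expected to be integral or string values.
INTEGRAL_OR_STRING = {"IntegerType", "LongType", "StringType"}


def check_ranker_columns(schema: Dict[str, str], label_col: str, group_col: str) -> List[str]:
    """Return warnings about the label/group column choice for a ranker."""
    warnings = []
    if group_col not in schema:
        warnings.append(f"group column '{group_col}' is missing from the schema")
    if label_col not in schema:
        warnings.append(f"label column '{label_col}' is missing from the schema")
    if group_col == label_col:
        # The query plan in this issue appears to show the label being cast to int
        # when it is also the group column; reading it as a double then fails.
        warnings.append(f"'{group_col}' is used as both label and group column")
    elif group_col in schema and schema[group_col] not in INTEGRAL_OR_STRING:
        warnings.append(f"group column '{group_col}' has type {schema[group_col]}; "
                        "an integer or string query id is assumed")
    return warnings


# Schema from this issue: no query id column, so groupCol points at the label.
issue_schema = {"label": "DoubleType", "features": "VectorUDT"}
for w in check_ranker_columns(issue_schema, label_col="label", group_col="label"):
    print(w)  # 'label' is used as both label and group column
```

With a hypothetical query_id column of type IntegerType added to the schema and passed as the group column, the same check returns no warnings.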

Info:

  • MMLSpark version: v1.0.0-rc1
  • Spark version: 2.4.2
  • Spark platform: Dataproc
Stacktrace
Spark history server logs:

java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double
	at scala.runtime.BoxesRunTime.unboxToDouble(BoxesRunTime.java:114)
	at org.apache.spark.sql.Row$class.getDouble(Row.scala:248)
	at org.apache.spark.sql.catalyst.expressions.GenericRow.getDouble(rows.scala:166)
	at com.microsoft.ml.spark.lightgbm.TrainUtils$$anonfun$3.apply(TrainUtils.scala:29)
	at com.microsoft.ml.spark.lightgbm.TrainUtils$$anonfun$3.apply(TrainUtils.scala:29)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at com.microsoft.ml.spark.lightgbm.TrainUtils$.generateDataset(TrainUtils.scala:29)
	at com.microsoft.ml.spark.lightgbm.TrainUtils$.translate(TrainUtils.scala:233)
	at com.microsoft.ml.spark.lightgbm.TrainUtils$.trainLightGBM(TrainUtils.scala:385)
	at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145)
	at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145)
	at org.apache.spark.sql.execution.MapPartitionsExec$$anonfun$5.apply(objects.scala:188)
	at org.apache.spark.sql.execution.MapPartitionsExec$$anonfun$5.apply(objects.scala:185)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:55)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Additional context
We were able to rule out the code at https://github.com/Azure/mmlspark/blob/master/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala#L98, where getDouble is called on the label, since that value is of type vector.

The Spark SQL query plan suggests a possible explanation involving the label: it is cast to int during a projection and then read as a double during deserialization, which could be the cause of this issue.
[Screenshot of the query plan, 2020-03-10 11:20:18 PM]
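
The mechanism described above can be modeled in a few lines of plain Python. This is a toy sketch, not Spark or MMLSpark code: rows are dicts, `project_group_to_int` is a hypothetical stand-in for the projection seen in the query plan, and `get_double` mimics Row.getDouble, which calls scala.runtime.BoxesRunTime.unboxToDouble (see the stack trace). That method casts the boxed value to java.lang.Double, so a boxed java.lang.Integer is rejected instead of being widened.

```python
# Toy model of the failure (plain Python, not Spark code).
from typing import Any, Dict

Row = Dict[str, Any]


def project_group_to_int(row: Row, group_col: str) -> Row:
    """Hypothetical stand-in for the projection: cast only the group column to int."""
    projected = dict(row)
    projected[group_col] = int(projected[group_col])
    return projected


def get_double(row: Row, col: str) -> float:
    """Mimic Row.getDouble -> BoxesRunTime.unboxToDouble.

    On the JVM the boxed value is cast to java.lang.Double, so a boxed Integer
    fails. Here we model that by rejecting any value whose exact type is not float.
    """
    value = row[col]
    if type(value) is not float:
        raise TypeError(f"{type(value).__name__} cannot be cast to float")
    return value


row = {"label": 3.0, "features": [0.1, 0.2]}

# groupCol='label': the label itself is cast to int, so the double read fails.
try:
    get_double(project_group_to_int(row, "label"), "label")
except TypeError as err:
    print("fails:", err)  # fails: int cannot be cast to float

# A separate (hypothetical) query_id group column leaves the label a float.
row_with_group = {**row, "query_id": 7}
print(get_double(project_group_to_int(row_with_group, "query_id"), "label"))  # 3.0
```

The model only reproduces the behavior inferred from the query plan; whether LightGBMRanker actually casts the group column this way in v1.0.0-rc1 is an assumption here.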