Skip to content

Commit

Permalink
[jvm-packages] Don't cast to float if it's already float (#10386)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 4, 2024
1 parent 9b7633c commit bc7643d
Showing 1 changed file with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ private[spark] object GpuUtils {
val featureNameSet = featureNames.distinct
validateSchema(dataset.schema, featureNameSet, labelName, weightName, marginName, fitting)

val castToFloat = (ds: Dataset[_], colName: String) => {
val colMeta = ds.schema(colName).metadata
ds.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
val castToFloat = (df: DataFrame, colName: String) => {
if (df.schema(colName).dataType.isInstanceOf[FloatType]) {
df
} else {
val colMeta = df.schema(colName).metadata
df.withColumn(colName, col(colName).as(colName, colMeta).cast(FloatType))
}
}
val colNames = if (fitting) {
var names = featureNameSet :+ labelName
Expand Down

0 comments on commit bc7643d

Please sign in to comment.