diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala new file mode 100644 index 0000000000000..f3ce6dfca2c1c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.BinaryAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} + +/** + * :: AlphaComponent :: + * Binarize a column of continuous features given a threshold. + */ +@AlphaComponent +final class Binarizer extends Transformer with HasInputCol with HasOutputCol { + + /** + * Param for threshold used to binarize continuous features. + * The features greater than the threshold, will be binarized to 1.0. + * The features equal to or less than the threshold, will be binarized to 0.0. + * @group param + */ + val threshold: DoubleParam = + new DoubleParam(this, "threshold", "threshold used to binarize continuous features") + + /** @group getParam */ + def getThreshold: Double = getOrDefault(threshold) + + /** @group setParam */ + def setThreshold(value: Double): this.type = set(threshold, value) + + setDefault(threshold -> 0.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val td = map(threshold) + val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } + val outputColName = map(outputCol) + val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() + dataset.select(col("*"), + binarizer(col(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType) + + val inputFields = schema.fields + val outputColName = map(outputCol) + + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + + val attr = BinaryAttribute.defaultAttr.withName(outputColName) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala new file mode 100644 index 0000000000000..caf1b759593f3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +class BinarizerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var data: Array[Double] = _ + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) + } + + test("Binarize continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame( + data.zip(defaultBinarized)).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The feature value is not correct after binarization.") + } + } + + test("Binarize continuous features with setter") { + val threshold: Double = 0.2 + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame( + data.zip(thresholdBinarized)).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The feature value is not correct after binarization.") + } + } +}