Add clean missing data module
imatiach-msft committed Jun 15, 2017
1 parent 016883d · commit ee412f5
Showing 4 changed files with 364 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/clean-missing-data/build.sbt
@@ -0,0 +1,2 @@
//> DependsOn: core
//> DependsOn: utils
208 changes: 208 additions & 0 deletions src/clean-missing-data/src/main/scala/CleanMissingData.scala
@@ -0,0 +1,208 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types.StructType

import scala.collection.mutable.ListBuffer

object CleanMissingData extends DefaultParamsReadable[CleanMissingData] {
  val meanOpt = "Mean"
  val medianOpt = "Median"
  val customOpt = "Custom"
  val modes = Array(meanOpt, medianOpt, customOpt)

  def validateAndTransformSchema(schema: StructType,
                                 inputCols: Array[String],
                                 outputCols: Array[String]): StructType = {
    inputCols.zip(outputCols).foldLeft(schema)((oldSchema, io) => {
      // The output column takes the input column's field, renamed to the output name,
      // matching the aliasing done in CleanMissingDataModel.transform
      val inputField = oldSchema.fields(oldSchema.fieldIndex(io._1)).copy(name = io._2)
      if (oldSchema.fieldNames.contains(io._2)) {
        val index = oldSchema.fieldIndex(io._2)
        val fields = oldSchema.fields.clone()
        fields(index) = inputField
        StructType(fields)
      } else {
        oldSchema.add(inputField)
      }
    })
  }
}

/**
  * Replaces missing values in the input dataset.
  * The following modes are supported:
  *   Mean   - replaces missing values with the mean of the fitted column
  *   Median - replaces missing values with the approximate median of the fitted column
  *   Custom - replaces missing values with a custom value specified by the user
  */
class CleanMissingData(override val uid: String) extends Estimator[CleanMissingDataModel]
  with HasInputCols with HasOutputCols with MMLParams {

  def this() = this(Identifiable.randomUID("CleanMissingData"))

  val cleaningMode = StringParam(this, "cleaningMode", "cleaning mode", CleanMissingData.meanOpt)
  def setCleaningMode(value: String): this.type = set(cleaningMode, value)
  def getCleaningMode: String = $(cleaningMode)

  val customValue = DoubleParam(this, "customValue", "custom value for replacement")
  def setCustomValue(value: Double): this.type = set(customValue, value)
  def getCustomValue: Double = $(customValue)

  /**
    * Fits the dataset and prepares the transformation function.
    *
    * @param dataset The input dataset.
    * @return The model for replacing missing values.
    */
  override def fit(dataset: Dataset[_]): CleanMissingDataModel = {
    val replacementValues = getReplacementValues(dataset, getInputCols, getOutputCols, getCleaningMode)
    new CleanMissingDataModel(uid, replacementValues, getInputCols, getOutputCols)
  }

  override def copy(extra: ParamMap): Estimator[CleanMissingDataModel] = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    CleanMissingData.validateAndTransformSchema(schema, getInputCols, getOutputCols)

  private def getReplacementValues(dataset: Dataset[_],
                                   colsToClean: Array[String],
                                   outputCols: Array[String],
                                   mode: String): Map[String, Any] = {
    import org.apache.spark.sql.functions._
    val columns = colsToClean.map(name => dataset(name))
    val metrics = mode match {
      case CleanMissingData.meanOpt =>
        val row = dataset.select(columns.map(column => avg(column)): _*).collect()(0)
        rowToValues(row)
      case CleanMissingData.medianOpt =>
        val row = dataset
          .select(columns.map(column => callUDF("percentile_approx", column, lit(0.5))): _*)
          .collect()(0)
        rowToValues(row)
      case CleanMissingData.customOpt =>
        // The same custom value is used for every column to clean
        colsToClean.map(_ => getCustomValue)
    }
    outputCols.zip(metrics).toMap
  }

  private def rowToValues(row: Row): Array[Double] = {
    val avgs = ListBuffer[Double]()
    for (i <- 0 until row.size) {
      avgs += row.getDouble(i)
    }
    avgs.toArray
  }
}

/**
  * Model produced by [[CleanMissingData]].
  */
class CleanMissingDataModel(val uid: String,
                            val replacementValues: Map[String, Any],
                            val inputCols: Array[String],
                            val outputCols: Array[String])
  extends Model[CleanMissingDataModel] with MLWritable {

  override def write: MLWriter =
    new CleanMissingDataModel.CleanMissingDataModelWriter(uid, replacementValues, inputCols, outputCols)

  override def copy(extra: ParamMap): CleanMissingDataModel =
    new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols)
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Keep every existing column, then alias each input column to its output
    // name (when the two differ) so the filled values land in the output columns
    val datasetCols = dataset.columns.map(name => dataset(name)).toList
    val datasetInputCols = inputCols.zip(outputCols)
      .flatMap(io =>
        if (io._1 == io._2) None
        else Some(dataset(io._1).as(io._2)))
      .toList
    val addedCols = dataset.select((datasetCols ::: datasetInputCols): _*)
    addedCols.na.fill(replacementValues)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    CleanMissingData.validateAndTransformSchema(schema, inputCols, outputCols)
}

object CleanMissingDataModel extends MLReadable[CleanMissingDataModel] {

  private val replacementValuesPart = "replacementValues"
  private val inputColsPart = "inputCols"
  private val outputColsPart = "outputCols"
  private val dataPart = "data"

  override def read: MLReader[CleanMissingDataModel] = new CleanMissingDataModelReader

  override def load(path: String): CleanMissingDataModel = super.load(path)

  /** [[MLWriter]] instance for [[CleanMissingDataModel]] */
  private[CleanMissingDataModel]
  class CleanMissingDataModelWriter(val uid: String,
                                    val replacementValues: Map[String, Any],
                                    val inputCols: Array[String],
                                    val outputCols: Array[String])
    extends MLWriter {

    private case class Data(uid: String)

    override protected def saveImpl(path: String): Unit = {
      val overwrite = this.shouldOverwrite
      val qualPath = PipelineUtilities.makeQualifiedPath(sc, path)
      // Required in order to allow this to be part of an ML pipeline
      PipelineUtilities.saveMetadata(uid,
        CleanMissingDataModel.getClass.getName.replace("$", ""),
        new Path(path, "metadata").toString,
        sc,
        overwrite)

      // Save the replacement values
      ObjectUtilities.writeObject(replacementValues, qualPath, replacementValuesPart, sc, overwrite)

      // Save the input and output columns
      ObjectUtilities.writeObject(inputCols, qualPath, inputColsPart, sc, overwrite)
      ObjectUtilities.writeObject(outputCols, qualPath, outputColsPart, sc, overwrite)

      // Save the model data
      val data = Data(uid)
      val dataPath = new Path(qualPath, dataPart).toString
      val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.mode(saveMode).parquet(dataPath)
    }
  }

  private class CleanMissingDataModelReader extends MLReader[CleanMissingDataModel] {

    override def load(path: String): CleanMissingDataModel = {
      val qualPath = PipelineUtilities.makeQualifiedPath(sc, path)
      // Load the uid
      val dataPath = new Path(qualPath, dataPart).toString
      val data = sparkSession.read.format("parquet").load(dataPath)
      val Row(uid: String) = data.select("uid").head()

      // Load the replacement values
      val replacementValues = ObjectUtilities.loadObject[Map[String, Any]](qualPath, replacementValuesPart, sc)

      // Load the input and output columns
      val inputCols = ObjectUtilities.loadObject[Array[String]](qualPath, inputColsPart, sc)
      val outputCols = ObjectUtilities.loadObject[Array[String]](qualPath, outputColsPart, sc)

      new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols)
    }
  }
}
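
For orientation, here is a minimal usage sketch of the new estimator (not part of the commit; the DataFrame df, the column names, and the save path are illustrative assumptions):

import com.microsoft.ml.spark.{CleanMissingData, CleanMissingDataModel}

// Hypothetical input: a DataFrame df with nullable numeric columns "a" and "b"
val cleaner = new CleanMissingData()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a", "b"))
  .setCleaningMode(CleanMissingData.medianOpt)

val model = cleaner.fit(df)        // computes one replacement value per input column
val cleaned = model.transform(df)  // fills nulls via na.fill on the output columns

// Custom mode substitutes a fixed value instead of a fitted statistic
val customCleaner = new CleanMissingData()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a", "b"))
  .setCleaningMode(CleanMissingData.customOpt)
  .setCustomValue(0.0)

// The fitted model round-trips through MLWritable / MLReadable
model.write.overwrite().save("/tmp/cleanMissingDataModel")
val reloaded = CleanMissingDataModel.load("/tmp/cleanMissingDataModel")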
142 changes: 142 additions & 0 deletions src/clean-missing-data/src/test/scala/VerifyCleanMissingData.scala
@@ -0,0 +1,142 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import org.apache.spark.ml.Estimator
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import java.lang.{Double => JDouble, Integer => JInt}

import org.scalactic.TolerantNumerics

/**
  * Tests to validate the functionality of the CleanMissingData estimator.
  */
class VerifyCleanMissingData extends EstimatorFuzzingTest {

  val tolerance = 0.01
  implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(tolerance)
  val tolEq = TolerantNumerics.tolerantDoubleEquality(tolerance)

  import session.implicits._

  def createMockDataset: DataFrame = {
    Seq[(JInt, JInt, JDouble, JDouble, JInt)](
      (0, 2, 0.50, 0.60, 0),
      (1, 3, 0.40, null, null),
      (0, 4, 0.78, 0.99, 2),
      (1, 5, 0.12, 0.34, 3),
      (0, 1, 0.50, 0.60, 0),
      (null, null, null, null, null),
      (0, 3, 0.78, 0.99, 2),
      (1, 4, 0.12, 0.34, 3),
      (0, null, 0.50, 0.60, 0),
      (1, 2, 0.40, 0.50, null),
      (0, 3, null, 0.99, 2),
      (1, 4, 0.12, 0.34, 3))
      .toDF("col1", "col2", "col3", "col4", "col5")
  }

test("Test for cleaning missing data with mean") {
val dataset = createMockDataset
val cmd = new CleanMissingData()
.setInputCols(dataset.columns)
.setOutputCols(dataset.columns)
.setCleaningMode(CleanMissingData.meanOpt)
val cmdModel = cmd.fit(dataset)
val result = cmdModel.transform(dataset)
// Calculate mean of column values
val numCols = dataset.columns.length
val meanValues = Array.ofDim[Double](numCols)
val counts = Array.ofDim[Double](numCols)
val collected = dataset.collect()
collected.foreach(row => {
for (i <- 0 until numCols) {
val rawValue = row.get(i)
val rowValue =
if (rawValue == null) 0
else if (i == 2 || i == 3) {
counts(i) += 1
row.get(i).asInstanceOf[JDouble].doubleValue()
} else {
counts(i) += 1
row.get(i).asInstanceOf[JInt].doubleValue()
}
meanValues(i) += rowValue
}
})
for (i <- 0 until numCols) {
meanValues(i) /= counts(i)
if (i != 2 && i != 3) {
meanValues(i) = meanValues(i).toInt.toDouble
}
}
verifyReplacementValues(dataset, result, meanValues)
}

test("Test for cleaning missing data with median") {
val dataset = createMockDataset
val cmd = new CleanMissingData()
.setInputCols(dataset.columns)
.setOutputCols(dataset.columns)
.setCleaningMode(CleanMissingData.medianOpt)
val cmdModel = cmd.fit(dataset)
val result = cmdModel.transform(dataset)
val medianValues = Array[Double](0, 3, 0.4, 0.6, 2)
verifyReplacementValues(dataset, result, medianValues)
}

test("Test for cleaning missing data with custom value") {
val dataset = createMockDataset
val customValue = -1.5
val cmd = new CleanMissingData()
.setInputCols(dataset.columns)
.setOutputCols(dataset.columns)
.setCleaningMode(CleanMissingData.customOpt)
.setCustomValue(customValue)
val cmdModel = cmd.fit(dataset)
val result = cmdModel.transform(dataset)
val replacesValues = Array.fill[Double](dataset.columns.length)(customValue)
val numCols = replacesValues.length
for (i <- 0 until numCols) {
if (i != 2 && i != 3) {
replacesValues(i) = replacesValues(i).toInt.toDouble
}
}
verifyReplacementValues(dataset, result, replacesValues)
}

  private def verifyReplacementValues(expected: DataFrame, result: DataFrame, expectedValues: Array[Double]) = {
    val collectedExp = expected.collect()
    val collectedResult = result.collect()
    val numRows = result.count().toInt
    val numCols = result.columns.length
    for (j <- 0 until numRows) {
      for (i <- 0 until numCols) {
        val row = collectedExp(j)
        val (rowValue, actualValue) =
          if (i == 2 || i == 3) {
            (row.get(i).asInstanceOf[JDouble], collectedResult(j)(i).asInstanceOf[Double])
          } else {
            (row.get(i).asInstanceOf[JInt], collectedResult(j)(i).asInstanceOf[Int].toDouble)
          }
        // Only cells that were null in the input should hold a replacement value
        if (rowValue == null) {
          val expectedValue = expectedValues(i)
          assert(tolEq.areEquivalent(expectedValue, actualValue),
            s"Values do not match, expected: $expectedValue, result: $actualValue")
        }
      }
    }
  }

  override def createFitDataset: DataFrame = {
    createMockDataset
  }

  // Intentionally unimplemented; not required by these tests
  override def schemaForDataset: StructType = ???

  override def getEstimator(): Estimator[_] = {
    val dataset = createFitDataset
    new CleanMissingData().setInputCols(dataset.columns).setOutputCols(dataset.columns)
  }
}
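
As a cross-check on the hard-coded medianValues above: in median mode the estimator calls Spark SQL's percentile_approx at probability 0.5, so the expected values can be reproduced directly. A hypothetical sketch, assuming the mock dataset above is bound to dataset:

import org.apache.spark.sql.functions.{callUDF, lit}

// One approximate median per column, mirroring CleanMissingData's median mode;
// per the expectations above this should yield 0, 3, 0.4, 0.6, 2 for col1..col5
dataset.select(
  dataset.columns.map(c => callUDF("percentile_approx", dataset(c), lit(0.5))): _*
).show()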
12 changes: 12 additions & 0 deletions src/core/contracts/src/main/scala/Params.scala
@@ -121,6 +121,18 @@ trait HasOutputCol extends Wrappable {
  def getOutputCol: String = $(outputCol)
}

trait HasInputCols extends Wrappable {
  val inputCols = new StringArrayParam(this, "inputCols", "The names of the input columns")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def getInputCols: Array[String] = $(inputCols)
}

trait HasOutputCols extends Wrappable {
  val outputCols = new StringArrayParam(this, "outputCols", "The names of the output columns")
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
  def getOutputCols: Array[String] = $(outputCols)
}

trait HasLabelCol extends Wrappable {
  val labelCol = StringParam(this, "labelCol", "The name of the label column")
  def setLabelCol(value: String): this.type = set(labelCol, value)
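
The two new traits mirror the existing HasInputCol/HasOutputCol pattern, so any pipeline stage can mix them in the way CleanMissingData does above. A hypothetical minimal sketch (NoOpTransformer is illustrative only, and assumes Wrappable imposes no members beyond Params):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical stage: passes data through unchanged, but exposes the new
// inputCols/outputCols params via the traits added in this commit
class NoOpTransformer(override val uid: String) extends Transformer
    with HasInputCols with HasOutputCols {
  def this() = this(Identifiable.randomUID("NoOpTransformer"))
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): NoOpTransformer = defaultCopy(extra)
}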
