[SWPRIVATE-16] NA handling for Spark algorithms
mdymczyk committed Sep 6, 2016
1 parent 5cdd792 commit 488648b
Showing 6 changed files with 116 additions and 71 deletions.
9 changes: 8 additions & 1 deletion ml/src/main/scala/hex/schemas/SVMV3.java
@@ -17,6 +17,7 @@

package hex.schemas;

import org.apache.spark.ml.spark.models.MissingValuesHandling;
import org.apache.spark.ml.spark.models.svm.*;
import water.DKV;
import water.Key;
@@ -51,7 +52,8 @@ public static final class SVMParametersV3 extends
"gradient",

"ignored_columns",
"ignore_const_cols"
"ignore_const_cols",
"missing_values_handling"
};

@API(help="Initial model weights.", direction=API.Direction.INOUT, gridable = true)
@@ -82,6 +84,11 @@ public static final class SVMParametersV3 extends
@API(help="Set the gradient computation type for SGD.", direction=API.Direction.INPUT, values = {"Hinge", "LeastSquares", "Logistic"}, required = true, gridable = true, level = API.Level.expert)
public Gradient gradient = Gradient.Hinge;

@API(level = API.Level.expert, direction = API.Direction.INOUT, gridable = true,
values = {"NotAllowed", "Skip", "MeanImputation"},
help = "Handling of missing values. Either NotAllowed, Skip or MeanImputation.")
public MissingValuesHandling missing_values_handling;

@Override
public SVMParametersV3 fillFromImpl(SVMParameters impl) {
super.fillFromImpl(impl);
70 changes: 70 additions & 0 deletions ml/src/main/scala/org/apache/spark/ml/FrameMLUtils.scala
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml

import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.spark.models.MissingValuesHandling
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{DataTypes, StructField}
import water.fvec.{Frame, H2OFrame}

object FrameMLUtils {
def toLabeledPoints(parms: Frame,
_response_column: String,
nfeatures: Int,
means: Array[Double],
missingHandler: MissingValuesHandling,
h2oContext: H2OContext,
sqlContext: SQLContext): RDD[LabeledPoint] = {
val domains = parms.domains()

val trainingDF = h2oContext.asDataFrame(new H2OFrame(parms))(sqlContext)
val fields: Array[StructField] = trainingDF.schema.fields
var trainingRDD = trainingDF.rdd

if (MissingValuesHandling.Skip.eq(missingHandler)) {
// Skip: drop every row that contains at least one missing value.
trainingRDD = trainingRDD.filter(!_.anyNull)
} else if (MissingValuesHandling.MeanImputation.eq(missingHandler)) {
// MeanImputation: pre-compute each column's mean over the observed (non-null) values.
(0 until nfeatures).
foreach(i => means(i) = trainingRDD.filter(!_.isNullAt(i)).map(row => toDouble(row.get(i), fields(i), domains(i))).mean())
}

trainingRDD.map(row => {
val features = new Array[Double](nfeatures)
// Replace a missing value with its column mean (non-zero only under MeanImputation).
(0 until nfeatures).foreach(i => features(i) = if (row.isNullAt(i)) means(i) else toDouble(row.get(i), fields(i), domains(i)))

new LabeledPoint(
toDouble(row.getAs[String](_response_column), fields(fields.length - 1), domains(domains.length - 1)),
Vectors.dense(features)
)
})
}

private def toDouble(value: Any, fieldStruct: StructField, domain: Array[String]): Double = {
fieldStruct.dataType match {
case DataTypes.ByteType => value.asInstanceOf[Byte].doubleValue
case DataTypes.ShortType => value.asInstanceOf[Short].doubleValue
case DataTypes.IntegerType => value.asInstanceOf[Integer].doubleValue
case DataTypes.DoubleType => value.asInstanceOf[Double]
case DataTypes.StringType => domain.indexOf(value)
case _ => throw new IllegalArgumentException("Target column has to be an enum or a number. " + fieldStruct.toString)
}
}
}
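A minimal usage sketch of the new helper, mirroring the call site in SVM.computeImpl below (hypothetical names: an active H2OContext hc, a SQLContext sqlCtx, and a training Frame fr whose last column "label" is the response):

    val nfeatures = fr.numCols() - 1
    val means = new Array[Double](nfeatures) // filled in by the call under MeanImputation
    val points = FrameMLUtils.toLabeledPoints(fr, "label", nfeatures, means,
      MissingValuesHandling.MeanImputation, hc, sqlCtx)
    points.cache() // the SVM builder caches before handing the RDD to SVMWithSGD

Note that means is an output parameter: after the call it holds the per-column averages used for imputation, which the builder then stores on the model output so scoring can reuse them.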
5 changes: 5 additions & 0 deletions ml/src/main/scala/org/apache/spark/ml/spark/models/MissingValuesHandling.java
@@ -0,0 +1,5 @@
package org.apache.spark.ml.spark.models;

public enum MissingValuesHandling {
NotAllowed, Skip, MeanImputation
}
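As wired up in this commit: NotAllowed makes SVM.init() report an error for any NA in a non-ignored training column; Skip drops rows containing NAs before conversion to LabeledPoints; MeanImputation replaces each NA with its column mean, computed in FrameMLUtils.toLabeledPoints and stored on the model output (numMeans) so score0 can impute the same way at prediction time.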
88 changes: 19 additions & 69 deletions ml/src/main/scala/org/apache/spark/ml/spark/models/svm/SVM.java
@@ -18,22 +18,19 @@

import hex.*;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.SparkContext;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.ml.FrameMLUtils;
import org.apache.spark.ml.spark.models.MissingValuesHandling;
import org.apache.spark.ml.spark.models.svm.SVMModel.SVMOutput;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import water.DKV;
import water.fvec.Frame;
import water.fvec.H2OFrame;
import water.fvec.Vec;
import water.util.Log;

@@ -95,11 +92,13 @@ public void init(boolean expensive) {
}
}

for (int i = 0; i < _train.numCols(); i++) {
Vec vec = _train.vec(i);
String vecName = _train.name(i);
if (vec.naCnt() > 0 && (null == _parms._ignored_columns || Arrays.binarySearch(_parms._ignored_columns, vecName) < 0)) {
error("_train", "Training frame cannot contain any missing values [" + vecName + "].");
if(MissingValuesHandling.NotAllowed.equals(_parms._missing_values_handling)) {
for (int i = 0; i < _train.numCols(); i++) {
Vec vec = _train.vec(i);
String vecName = _train.name(i);
if (vec.naCnt() > 0 && (null == _parms._ignored_columns || Arrays.binarySearch(_parms._ignored_columns, vecName) < 0)) {
error("_train", "Training frame cannot contain any missing values [" + vecName + "].");
}
}
}

@@ -109,7 +108,7 @@ public void init(boolean expensive) {
for (int i = 0; i < _train.vecs().length; i++) {
Vec vec = _train.vec(i);
if (!ignoredCols.contains(_train.name(i)) && !(vec.isNumeric() || vec.isCategorical())) {
error("_train", "SVM supports only frames with numeric values (except for result column). But a " + vec.get_type_str() + " was found.");
error("_train", "SVM supports only frames with numeric/categorical values (except for result column). But a " + vec.get_type_str() + " was found.");
}
}

@@ -169,10 +168,15 @@ public void computeImpl() {
SVMModel model = new SVMModel(dest(), _parms, new SVMModel.SVMOutput(SVM.this));
model.delete_and_lock(_job);

RDD<LabeledPoint> training = getTrainingData(
double[] means = new double[model._output.nfeatures()];
RDD<LabeledPoint> training = FrameMLUtils.toLabeledPoints(
_train,
_parms._response_column,
model._output.nfeatures()
model._output.nfeatures(),
means,
_parms._missing_values_handling,
h2oContext,
sqlContext
);
training.cache();

@@ -198,6 +202,8 @@ public void computeImpl() {
model._output.iterations_$eq(_parms._max_iterations);
model._output.interceptor_$eq(trainedModel.intercept());

model._output.numMeans_$eq(means);

Frame train = DKV.<Frame>getGet(_parms._train);
model.score(train).delete();
model._output._training_metrics = ModelMetrics.getFromDKV(model, train);
@@ -225,61 +231,5 @@ private Vector vec2vec(Vec[] vals) {
}
return Vectors.dense(dense);
}

private RDD<LabeledPoint> getTrainingData(Frame parms, String _response_column, int nfeatures) {
return h2oContext.asDataFrame(new H2OFrame(parms), true, sqlContext)
.javaRDD()
.map(new RowToLabeledPoint(nfeatures, _response_column, parms.domains())).rdd();
}
}
}

class RowToLabeledPoint implements Function<Row, LabeledPoint> {
private final int nfeatures;
private final String _response_column;
private final String[][] domains;

RowToLabeledPoint(int nfeatures, String response_column, String[][] domains) {
this.nfeatures = nfeatures;
this._response_column = response_column;
this.domains = domains;
}

@Override
public LabeledPoint call(Row row) throws Exception {
StructField[] fields = row.schema().fields();
double[] features = new double[nfeatures];
for (int i = 0; i < nfeatures; i++) {
features[i] = toDouble(row.get(i), fields[i], domains[i]);
}

return new LabeledPoint(
toDouble(row.<String>getAs(_response_column), fields[fields.length - 1], domains[domains.length - 1]),
Vectors.dense(features));
}

private double toDouble(Object value, StructField fieldStruct, String[] domain) {
if (fieldStruct.dataType().sameType(DataTypes.ByteType)) {
return ((Byte) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.ShortType)) {
return ((Short) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.IntegerType)) {
return ((Integer) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.DoubleType)) {
return (Double) value;
}

if (fieldStruct.dataType().sameType(DataTypes.StringType)) {
return Arrays.binarySearch(domain, value);
}

throw new IllegalArgumentException("Target column has to be an enum or a number. " + fieldStruct.toString());
}
}

13 changes: 12 additions & 1 deletion ml/src/main/scala/org/apache/spark/ml/spark/models/svm/SVMModel.scala
@@ -16,8 +16,11 @@
*/
package org.apache.spark.ml.spark.models.svm

import java.lang

import hex.ModelMetricsSupervised.MetricBuilderSupervised
import hex._
import org.apache.spark.ml.spark.models.MissingValuesHandling
import water.codegen.CodeGeneratorPipeline
import water.util.{JCodeGen, SBPrintStream}
import water.{H2O, Key, Keyed}
@@ -28,6 +31,7 @@ object SVMModel {
var interceptor: Double = .0
var iterations: Int = 0
var weights: Array[Double] = null
var numMeans: Array[Double] = null
}

}
@@ -49,8 +53,15 @@ class SVMModel private[svm](val selfKey: Key[_ <: Keyed[_ <: Keyed[_ <: AnyRef]]

protected def score0(data: Array[Double], preds: Array[Double]): Array[Double] = {
java.util.Arrays.fill(preds, 0)

val pred =
data.zip(_output.weights).foldRight(_output.interceptor){ case ((d, w), acc) => d * w + acc}
data.zip(_output.weights).zipWithIndex.foldRight(_output.interceptor) { case (((d, w), i), acc) =>
  // Score-time imputation: substitute the stored column mean for a missing feature.
  if (MissingValuesHandling.MeanImputation.eq(_parms._missing_values_handling) && lang.Double.isNaN(d)) {
    _output.numMeans(i) * w + acc
  } else {
    d * w + acc
  }
}

if(_parms._threshold.isNaN) { // Regression
preds(0) = pred
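As a worked check of the imputing fold above (assumed numbers, not from the commit): with weights (0.5, -1.0), intercept 0.2, stored means (3.0, 4.0), and an incoming row (NaN, 2.0) under MeanImputation, the NaN in the first feature is replaced by its stored mean 3.0, so the raw prediction is 3.0 * 0.5 + 2.0 * (-1.0) + 0.2 = -0.3.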
2 changes: 2 additions & 0 deletions ml/src/main/scala/org/apache/spark/ml/spark/models/svm/SVMParameters.java
@@ -1,6 +1,7 @@
package org.apache.spark.ml.spark.models.svm;

import hex.Model;
import org.apache.spark.ml.spark.models.MissingValuesHandling;
import water.Key;
import water.fvec.Frame;

@@ -35,6 +36,7 @@ public final Frame initialWeights() {
public Updater _updater = Updater.L2;
public Gradient _gradient = Gradient.Hinge;
public Key<Frame> _initial_weights = null;
public MissingValuesHandling _missing_values_handling = MissingValuesHandling.MeanImputation;

public void validate(SVM svm) {
if (_max_iterations < 0 || _max_iterations > 1e6) {
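A short sketch of overriding the default from client code (assumed setup; as set above, the default is MeanImputation):

    val parms = new SVMParameters()
    parms._missing_values_handling = MissingValuesHandling.Skip // opt out of mean imputation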
