[SWPRIVATE-16] NA handling for Spark algorithms (#75)
mdymczyk authored and jakubhava committed Dec 15, 2016
1 parent a06c0ca commit b6102f8
Showing 13 changed files with 415 additions and 140 deletions.
25 changes: 1 addition & 24 deletions core/build.gradle
@@ -1,6 +1,7 @@
description = "Sparkling Water Core"

apply from: "$rootDir/gradle/utils.gradle"
apply from: "$rootDir/gradle/sparkTest.gradle"

dependencies {
// Required for h2o-app (we need UI)
@@ -54,30 +55,6 @@ dependencies {
integTestRuntime fileTree(dir: new File((String) sparkHome, "lib/"), include: '*.jar' )
}

// Setup test environment for Spark
test {
// Test environment
systemProperty "spark.testing", "true"
systemProperty "spark.ext.h2o.node.log.dir", new File(project.getBuildDir(), "h2ologs-test/nodes")
systemProperty "spark.ext.h2o.client.log.dir", new File(project.getBuildDir(), "h2ologs-test/client")

// Run with assertions ON
enableAssertions = true

// For a new JVM for each test class
forkEvery = 1

// Increase heap size
maxHeapSize = "4g"

// Increase PermGen
jvmArgs '-XX:MaxPermSize=384m'

// Working dir will be root project
workingDir = rootDir
// testLogging.showStandardStreams = true
}

task createSparkVersionFile << {
File version_file = file("src/main/resources/spark.version")
// Create parent directories if not created yet
@@ -543,21 +543,6 @@ class DataFrameConverterTest extends FunSuite with SharedSparkTestContext {
h2oFrameEnum.delete()
}

def makeH2OFrame[T: ClassTag](fname: String, colNames: Array[String], chunkLayout: Array[Long],
data: Array[Array[T]], h2oType: Byte, colDomains: Array[Array[String]] = null): H2OFrame = {
var f: Frame = new Frame(Key.make(fname))
FrameUtils.preparePartialFrame(f,colNames)
f.update()

for( i <- chunkLayout.indices) { buildChunks(fname, data(i), i, Array(h2oType)) }

f = DKV.get(fname).get()

FrameUtils.finalizePartialFrame(f, chunkLayout, colDomains, Array(h2oType))

new H2OFrame(f)
}

def fp(it:Iterator[Row]):Unit = {
println(it.size)
}
@@ -21,9 +21,9 @@ import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.SparkContext
import org.apache.spark.h2o.{H2OConf, H2OContext, Holder}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.Suite
import water.fvec.{Chunk, FrameUtils, NewChunk, Vec}
import water.fvec._
import water.{DKV, Key}
import water.parser.BufferedString

import scala.reflect.ClassTag
@@ -32,9 +32,10 @@ import scala.reflect.ClassTag
* Helper trait to simplify initialization and termination of Spark/H2O contexts.
*
*/
trait SharedSparkTestContext extends SparkTestContext { self: Suite =>
trait SharedSparkTestContext extends SparkTestContext {
self: Suite =>

def createSparkContext:SparkContext
def createSparkContext: SparkContext

def createH2OContext(sc: SparkContext, conf: H2OConf): H2OContext = {
H2OContext.getOrCreate(sc, conf)
@@ -51,26 +52,51 @@ trait SharedSparkTestContext extends SparkTestContext { self: Suite =>
super.afterAll()
}

def buildChunks[T: ClassTag](fname: String, data: Array[T], cidx: Integer, h2oType: Array[Byte]): Chunk = {
def makeH2OFrame[T: ClassTag](fname: String, colNames: Array[String], chunkLayout: Array[Long],
data: Array[Array[T]], h2oType: Byte, colDomains: Array[Array[String]] = null): H2OFrame = {
makeH2OFrame2(fname, colNames, chunkLayout, data.map(_.map(value => Array(value))), Array(h2oType), colDomains)
}

def makeH2OFrame2[T: ClassTag](fname: String, colNames: Array[String], chunkLayout: Array[Long],
data: Array[Array[Array[T]]], h2oTypes: Array[Byte], colDomains: Array[Array[String]] = null): H2OFrame = {
var f: Frame = new Frame(Key.make(fname))
FrameUtils.preparePartialFrame(f, colNames)
f.update()

for (i <- chunkLayout.indices) {
buildChunks(fname, data(i), i, h2oTypes)
}

f = DKV.get(fname).get()

FrameUtils.finalizePartialFrame(f, chunkLayout, colDomains, h2oTypes)

new H2OFrame(f)
}

def buildChunks[T: ClassTag](fname: String, data: Array[Array[T]], cidx: Integer, h2oType: Array[Byte]): Array[_ <: Chunk] = {
val nchunks: Array[NewChunk] = FrameUtils.createNewChunks(fname, h2oType, cidx)

val chunk: NewChunk = nchunks(0)
data.foreach {
case u: UUID => chunk.addUUID(
u.getLeastSignificantBits,
u.getMostSignificantBits)
case s: String => chunk.addStr(new BufferedString(s))
case b: Byte => chunk.addNum(b)
case s: Short => chunk.addNum(s)
case c: Integer if h2oType(0) == Vec.T_CAT => chunk.addCategorical(c)
case i: Integer if h2oType(0) != Vec.T_CAT => chunk.addNum(i.toDouble)
case l: Long => chunk.addNum(l)
case d: Double => chunk.addNum(d)
case x =>
throw new IllegalArgumentException(s"Failed to figure out what is it: $x")
data.foreach { values =>
values.indices.foreach { idx =>
val chunk: NewChunk = nchunks(idx)
values(idx) match {
case null => chunk.addNA()
case u: UUID => chunk.addUUID(u.getLeastSignificantBits, u.getMostSignificantBits)
case s: String => chunk.addStr(new BufferedString(s))
case b: Byte => chunk.addNum(b)
case s: Short => chunk.addNum(s)
case c: Integer if h2oType(0) == Vec.T_CAT => chunk.addCategorical(c)
case i: Integer if h2oType(0) != Vec.T_CAT => chunk.addNum(i.toDouble)
case l: Long => chunk.addNum(l)
case d: Double => chunk.addNum(d)
case x =>
throw new IllegalArgumentException(s"Failed to figure out what is it: $x")
}
}
}
FrameUtils.closeNewChunks(nchunks)
chunk
nchunks
}
}

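Note: because buildChunks now routes null values to chunk.addNA(), suites that mix in SharedSparkTestContext can build H2O frames containing missing cells. A minimal sketch of such a test, assuming a plain local SparkContext; the suite name, column name, and data values are illustrative and not part of this commit:

import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import water.fvec.Vec

// Hypothetical test; only makeH2OFrame/buildChunks come from this commit.
class FrameWithNATest extends FunSuite with SharedSparkTestContext {
  override def createSparkContext: SparkContext = new SparkContext("local[*]", "na-test")

  test("null cells are stored as NAs") {
    // One numeric column laid out as a single four-row chunk; the null becomes an NA via addNA().
    val fr = makeH2OFrame("na_test.hex", Array("C0"), Array(4L),
      Array(Array[Any](1.0, null, 3.0, 4.0)), Vec.T_NUM)
    assert(fr.vec(0).naCnt() == 1)
    fr.delete()
  }
}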
23 changes: 23 additions & 0 deletions gradle/sparkTest.gradle
@@ -0,0 +1,23 @@
// Setup test environment for Spark
test {
// Test environment
systemProperty "spark.testing", "true"
systemProperty "spark.ext.h2o.node.log.dir", new File(project.getBuildDir(), "h2ologs-test/nodes")
systemProperty "spark.ext.h2o.client.log.dir", new File(project.getBuildDir(), "h2ologs-test/client")

// Run with assertions ON
enableAssertions = true

// For a new JVM for each test class
forkEvery = 1

// Increase heap size
maxHeapSize = "4g"

// Increase PermGen
jvmArgs '-XX:MaxPermSize=384m'

// Working dir will be root project
workingDir = rootDir
// testLogging.showStandardStreams = true
}
5 changes: 4 additions & 1 deletion ml/build.gradle
@@ -1,3 +1,5 @@
apply from: "$rootDir/gradle/sparkTest.gradle"

description = "Sparkling Water ML Pipelines"

dependencies {
@@ -36,4 +38,5 @@ sourceSets {
srcDirs = []
}
}
}
}

9 changes: 8 additions & 1 deletion ml/src/main/java/hex/schemas/SVMV3.java
@@ -17,6 +17,7 @@

package hex.schemas;

import org.apache.spark.ml.spark.models.MissingValuesHandling;
import org.apache.spark.ml.spark.models.svm.*;
import water.DKV;
import water.Key;
@@ -51,7 +52,8 @@ public static final class SVMParametersV3 extends
"gradient",

"ignored_columns",
"ignore_const_cols"
"ignore_const_cols",
"missing_values_handling"
};

@API(help="Initial model weights.", direction=API.Direction.INOUT, gridable = true)
@@ -82,6 +84,11 @@ public static final class SVMParametersV3 extends
@API(help="Set the gradient computation type for SGD.", direction=API.Direction.INPUT, values = {"Hinge", "LeastSquares", "Logistic"}, required = true, gridable = true, level = API.Level.expert)
public Gradient gradient = Gradient.Hinge;

@API(level = API.Level.expert, direction = API.Direction.INOUT, gridable = true,
values = {"NotAllowed", "Skip", "MeanImputation"},
help = "Handling of missing values. Either NotAllowed, Skip or MeanImputation.")
public MissingValuesHandling missing_values_handling;

@Override
public SVMParametersV3 fillFromImpl(SVMParameters impl) {
super.fillFromImpl(impl);
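Note: the new missing_values_handling field is exposed through the schema above and read as _parms._missing_values_handling by the SVM builder below. A hedged sketch of setting it when training from Scala; the parameter and enum names come from this diff, while the no-arg SVMParameters constructor, the SVM(parms) constructor, and trainModel() are assumptions following the usual H2O ModelBuilder pattern:

import org.apache.spark.ml.spark.models.MissingValuesHandling
import org.apache.spark.ml.spark.models.svm.{SVM, SVMParameters}
import water.fvec.H2OFrame

// Hypothetical training call; builder construction and trainModel() are assumptions.
def trainWithImputation(train: H2OFrame): Unit = {
  val parms = new SVMParameters()
  parms._train = train._key                    // frame already published in the DKV
  parms._response_column = "label"             // illustrative column name
  parms._missing_values_handling = MissingValuesHandling.MeanImputation // or Skip / NotAllowed
  val model = new SVM(parms).trainModel().get()
  println(model)
}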
98 changes: 26 additions & 72 deletions ml/src/main/java/org/apache/spark/ml/spark/models/svm/SVM.java
@@ -18,25 +18,24 @@

import hex.*;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.SparkContext;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.ml.spark.ProgressListener;
import org.apache.spark.ml.FrameMLUtils;
import org.apache.spark.ml.spark.models.MissingValuesHandling;
import org.apache.spark.ml.spark.models.svm.SVMModel.SVMOutput;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.storage.RDDInfo;

import scala.Tuple2;
import water.DKV;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.H2OFrame;
import water.fvec.Vec;
import water.util.Log;

@@ -100,11 +99,13 @@ public void init(boolean expensive) {
}
}

for (int i = 0; i < _train.numCols(); i++) {
Vec vec = _train.vec(i);
String vecName = _train.name(i);
if (vec.naCnt() > 0 && (null == _parms._ignored_columns || Arrays.binarySearch(_parms._ignored_columns, vecName) < 0)) {
error("_train", "Training frame cannot contain any missing values [" + vecName + "].");
if(MissingValuesHandling.NotAllowed == _parms._missing_values_handling) {
for (int i = 0; i < _train.numCols(); i++) {
Vec vec = _train.vec(i);
String vecName = _train.name(i);
if (vec.naCnt() > 0 && (null == _parms._ignored_columns || Arrays.binarySearch(_parms._ignored_columns, vecName) < 0)) {
error("_train", "Training frame cannot contain any missing values [" + vecName + "].");
}
}
}

@@ -114,7 +115,7 @@
for (int i = 0; i < _train.vecs().length; i++) {
Vec vec = _train.vec(i);
if (!ignoredCols.contains(_train.name(i)) && !(vec.isNumeric() || vec.isCategorical())) {
error("_train", "SVM supports only frames with numeric values (except for result column). But a " + vec.get_type_str() + " was found.");
error("_train", "SVM supports only frames with numeric/categorical values (except for result column). But a " + vec.get_type_str() + " was found.");
}
}

@@ -175,18 +176,26 @@ public void computeImpl() {
try {
model.delete_and_lock(_job);

RDD<LabeledPoint> training = getTrainingData(
Tuple2<RDD<LabeledPoint>, double[]> points = FrameMLUtils.toLabeledPoints(
_train,
_parms._response_column,
model._output.nfeatures()
model._output.nfeatures(),
_parms._missing_values_handling,
h2oContext,
sqlContext
);
RDD<LabeledPoint> training = points._1();
training.cache();

if(training.count() == 0 &&
MissingValuesHandling.Skip == _parms._missing_values_handling) {
throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
}


SVMWithSGD svm = new SVMWithSGD();
svm.setIntercept(_parms._add_intercept);

svm.optimizer().setNumIterations(_parms._max_iterations);

svm.optimizer().setStepSize(_parms._step_size);
svm.optimizer().setRegParam(_parms._reg_param);
svm.optimizer().setMiniBatchFraction(_parms._mini_batch_fraction);
@@ -207,13 +216,14 @@ public void computeImpl() {
svm.run(training, vec2vec(_parms.initialWeights().vecs()));
training.unpersist(false);


sc.listenerBus().listeners().remove(progressBar);

model._output.weights_$eq(trainedModel.weights().toArray());
model._output.iterations_$eq(_parms._max_iterations);
model._output.interceptor_$eq(trainedModel.intercept());

model._output.numMeans_$eq(points._2());

Frame train = DKV.<Frame>getGet(_parms._train);
model.score(train).delete();
model._output._training_metrics = ModelMetrics.getFromDKV(model, train);
@@ -243,61 +253,5 @@ private Vector vec2vec(Vec[] vals) {
}
return Vectors.dense(dense);
}

private RDD<LabeledPoint> getTrainingData(Frame parms, String _response_column, int nfeatures) {
return h2oContext.asDataFrame(new H2OFrame(parms), true, sqlContext)
.javaRDD()
.map(new RowToLabeledPoint(nfeatures, _response_column, parms.domains())).rdd();
}
}
}

class RowToLabeledPoint implements Function<Row, LabeledPoint> {
private final int nfeatures;
private final String _response_column;
private final String[][] domains;

RowToLabeledPoint(int nfeatures, String response_column, String[][] domains) {
this.nfeatures = nfeatures;
this._response_column = response_column;
this.domains = domains;
}

@Override
public LabeledPoint call(Row row) throws Exception {
StructField[] fields = row.schema().fields();
double[] features = new double[nfeatures];
for (int i = 0; i < nfeatures; i++) {
features[i] = toDouble(row.get(i), fields[i], domains[i]);
}

return new LabeledPoint(
toDouble(row.<String>getAs(_response_column), fields[fields.length - 1], domains[domains.length - 1]),
Vectors.dense(features));
}

private double toDouble(Object value, StructField fieldStruct, String[] domain) {
if (fieldStruct.dataType().sameType(DataTypes.ByteType)) {
return ((Byte) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.ShortType)) {
return ((Short) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.IntegerType)) {
return ((Integer) value).doubleValue();
}

if (fieldStruct.dataType().sameType(DataTypes.DoubleType)) {
return (Double) value;
}

if (fieldStruct.dataType().sameType(DataTypes.StringType)) {
return Arrays.binarySearch(domain, value);
}

throw new IllegalArgumentException("Target column has to be an enum or a number. " + fieldStruct.toString());
}
}
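Note: the per-row conversion that RowToLabeledPoint used to perform now lives in FrameMLUtils.toLabeledPoints, which also returns an array of column means (points._2(), stored on the model output as numMeans) and is not shown in this diff. A rough sketch, in Scala, of what the Skip and MeanImputation strategies mean when mapping rows to LabeledPoints; illustrative only, not the FrameMLUtils implementation:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Illustrative only. Each row carries the feature values followed by the response,
// with None marking a missing cell.
def toLabeledPointsSketch(rows: RDD[Array[Option[Double]]],
                          colMeans: Array[Double],
                          skipRowsWithNA: Boolean): RDD[LabeledPoint] = {
  val handled =
    if (skipRowsWithNA) rows.filter(_.forall(_.isDefined)) // Skip: drop any row containing an NA
    else rows.map(_.zipWithIndex.map { case (v, i) => v.orElse(Some(colMeans(i))) }) // MeanImputation
  handled.map { row =>
    val values = row.map(_.get)
    LabeledPoint(values.last, Vectors.dense(values.dropRight(1))) // last column is the response
  }
}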
