Skip to content

Commit

Permalink
PR review fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
mdymczyk committed Oct 7, 2016
1 parent 61d4769 commit ec2039d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
14 changes: 7 additions & 7 deletions ml/src/main/scala/org/apache/spark/ml/FrameMLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ object FrameMLUtils {
}

private[ml] def toDouble(value: Any, fieldStruct: StructField, domain: Array[String]): Double = {
fieldStruct.dataType match {
case DataTypes.ByteType => value.asInstanceOf[Byte].doubleValue
case DataTypes.ShortType => value.asInstanceOf[Short].doubleValue
case DataTypes.IntegerType => value.asInstanceOf[Integer].doubleValue
case DataTypes.DoubleType => value.asInstanceOf[Double]
case DataTypes.StringType => domain.indexOf(value)
case _ => throw new IllegalArgumentException("Target column has to be an enum or a number. " + fieldStruct.toString)
value match {
case b: Byte if fieldStruct.dataType == DataTypes.ByteType => b.doubleValue
case s: Short if fieldStruct.dataType == DataTypes.ShortType => s.doubleValue
case i: Int if fieldStruct.dataType == DataTypes.IntegerType => i.doubleValue
case d: Double if fieldStruct.dataType == DataTypes.DoubleType => d
case string: String if fieldStruct.dataType == DataTypes.StringType => domain.indexOf(string)
case _ => throw new IllegalArgumentException("Target column has to be an enum or a number. " + fieldStruct)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public void init(boolean expensive) {
}
}

if(MissingValuesHandling.NotAllowed.equals(_parms._missing_values_handling)) {
if(MissingValuesHandling.NotAllowed == _parms._missing_values_handling) {
for (int i = 0; i < _train.numCols(); i++) {
Vec vec = _train.vec(i);
String vecName = _train.name(i);
Expand Down Expand Up @@ -186,7 +186,8 @@ public void computeImpl() {
RDD<LabeledPoint> training = points._1();
training.cache();

if(training.count() == 0 && MissingValuesHandling.Skip.equals(_parms._missing_values_handling)) {
if(training.count() == 0 &&
MissingValuesHandling.Skip == _parms._missing_values_handling) {
throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
}

Expand Down

0 comments on commit ec2039d

Please sign in to comment.