In [1]:
spark

org.apache.spark.sql.SparkSession@63ded03c

### Spark Dataframes
![DataFrame](df1.png "Spark DataFrames")
***


## Spark ML Pipeline
![pipeline](ml-pipeline.png)
****


### WorkFlow
![dataframe](dataframe.png)
****


### Cross Validator for Best Model
![cross](crossvalidator.png)

In [2]:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.{Bucketizer, OneHotEncoder, StringIndexer, VectorAssembler, StandardScaler}
import com.ibm.snap.ml.{SnapLogisticRegression => LogisticRegression}
import com.ibm.snap.ml.{SnapLogisticRegressionModel => LogisticRegressionModel}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

import com.ibm.snap.ml.{SnapLogisticRegression=>LogisticRegression}
import com.ibm.snap.ml.{SnapLogisticRegressionModel=>LogisticRegressionModel}


In [3]:
import org.apache.spark.sql.types._

case class Flight(yr: Integer, mon: Integer, dofM: Integer, dofW: Integer, date: String, carrier: String, 
                  origin: String, dest: String, crsdeptime:String, deptime:String, depdelay:Double, 
                  crsarrtime:String, arrtime: String, arrdelay: Double, crselapsedtime: Double, 
                  elapsed: Double, air: Double, dist: Double, dummy: String) extends Serializable

  val schema = StructType(Array(
    StructField("yr", IntegerType, true),
    StructField("mon", IntegerType, true),
    StructField("dofM", IntegerType, true),
    StructField("dofW", IntegerType, true),
    StructField("date", StringType, true),
    StructField("carrier", StringType, true),
    StructField("origin", StringType, true),
    StructField("dest", StringType, true),
    StructField("crsdeptime", IntegerType, true),
    StructField("deptime", StringType, true),  
    StructField("depdelay", DoubleType, true),
    StructField("crsarrtime", IntegerType, true),
    StructField("arrtime", StringType, true),    
    StructField("arrdelay", DoubleType, true),
    StructField("canc", DoubleType, true),  
    StructField("crselapsedtime", DoubleType, true),
    StructField("elapsed", DoubleType, true),  
    StructField("air", DoubleType, true),    
    StructField("dist", DoubleType, true),
    StructField("dummy", StringType, true),
    StructField("crsdephour", IntegerType, true)
  ))

// StructField("crsdephour", IntegerType, true),crsdeptime, crsarrtime, 

defined class Flight
schema = StructType(StructField(yr,IntegerType,true), StructField(mon,IntegerType,true), StructField(dofM,IntegerType,true), StructField(dofW,IntegerType,true), StructField(date,StringType,true), StructField(carrier,StringType,true), StructField(origin,StringType,true), StructField(dest,StringType,true), StructField(crsdeptime,IntegerType,true), StructField(deptime,StringType,true), StructField(depdelay,DoubleType,true), StructField(crsarrtime,IntegerType,true), StructField(arrtime,StringType,true), StructField(arrdelay,DoubleType,true), StructField(canc,DoubleType,true), StructField(crselapsedtime,DoubleType,true), StructField(elapsed,DoubleType,true), StructField(air,DoubleType,true), Struc...


StructType(StructField(yr,IntegerType,true), StructField(mon,IntegerType,true), StructField(dofM,IntegerType,true), StructField(dofW,IntegerType,true), StructField(date,StringType,true), StructField(carrier,StringType,true), StructField(origin,StringType,true), StructField(dest,StringType,true), StructField(crsdeptime,IntegerType,true), StructField(deptime,StringType,true), StructField(depdelay,DoubleType,true), StructField(crsarrtime,IntegerType,true), StructField(arrtime,StringType,true), StructField(arrdelay,DoubleType,true), StructField(canc,DoubleType,true), StructField(crselapsedtime,DoubleType,true), StructField(elapsed,DoubleType,true), StructField(air,DoubleType,true), Struc...

In [4]:
def toInt = udf {(str: String) => {
  str.toInt
}}

def toDouble = udf {(str: String) => {
  if (str == "") {
        0.0
  } else {
      str.toDouble
  }
}}
def getHour = udf {(str: String) => {
    if (str == "") {
        0
    } else {
       str.toInt / 100 
    }
}}

toInt: org.apache.spark.sql.expressions.UserDefinedFunction
toDouble: org.apache.spark.sql.expressions.UserDefinedFunction
getHour: org.apache.spark.sql.expressions.UserDefinedFunction


In [5]:
val colNames = Seq("yr", "mon", "dofM", "dofW", "date", "carrier", "origin", "dest", "crsdeptime", "deptime", 
                   "depdelay", "crsarrtime", "arrtime", "arrdelay", "canc", "crselapsedtime","elapsed", "air", "dist", "dummy")



val datafilesDF = spark.read.format("csv").option("header", "true")
        .load("./data").toDF(colNames :_* )
        .na.drop(Seq("yr", "mon", "dofM", "dofW", "depdelay", "crsarrtime", "crsdeptime", "crselapsedtime", "air", "dist"))
        .withColumn("yr", toInt($"yr"))
        .withColumn("mon", toInt($"mon"))
        .withColumn("dofM", toInt($"dofM"))
        .withColumn("dofW", toInt($"dofW"))
        .withColumn("dofW", toInt($"dofW"))
        .withColumn("depdelay", toDouble($"depdelay"))
        .withColumn("crsarrtime", toInt($"crsarrtime"))
        .withColumn("crsdeptime", toInt($"crsdeptime"))
        .withColumn("arrdelay", toDouble($"arrdelay"))
        .withColumn("canc", toDouble($"canc"))
        .withColumn("crselapsedtime", toDouble($"crselapsedtime"))
        .withColumn("elapsed", toDouble($"elapsed"))
        .withColumn("air", toDouble($"air"))
        .withColumn("dist", toDouble($"dist"))
        .withColumn("crsdephour", getHour($"deptime"))

val flightDS = datafilesDF.as[Flight].cache

colNames = List(yr, mon, dofM, dofW, date, carrier, origin, dest, crsdeptime, deptime, depdelay, crsarrtime, arrtime, arrdelay, canc, crselapsedtime, elapsed, air, dist, dummy)
datafilesDF = [yr: int, mon: int ... 19 more fields]
flightDS = [yr: int, mon: int ... 19 more fields]


[yr: int, mon: int ... 19 more fields]

In [6]:
flightDS.show

+----+---+----+----+----------+-------+------------+---------------+----------+-------+--------+----------+-------+--------+----+--------------+-------+-----+------+-----+----------+
|  yr|mon|dofM|dofW|      date|carrier|      origin|           dest|crsdeptime|deptime|depdelay|crsarrtime|arrtime|arrdelay|canc|crselapsedtime|elapsed|  air|  dist|dummy|crsdephour|
+----+---+----+----+----------+-------+------------+---------------+----------+-------+--------+----------+-------+--------+----+--------------+-------+-----+------+-----+----------+
|2018| 12|  10|   1|2018-12-10|     WN|San Jose, CA|Los Angeles, CA|      2110|   2159|    49.0|      2230|   2306|    36.0| 0.0|          80.0|   67.0| 52.0| 308.0| null|        21|
|2018| 12|  10|   1|2018-12-10|     WN|San Jose, CA|Los Angeles, CA|      1445|   1440|    -5.0|      1605|   1559|    -6.0| 0.0|          80.0|   79.0| 51.0| 308.0| null|        14|
|2018| 12|  10|   1|2018-12-10|     WN|San Jose, CA|    Orlando, FL|      1305|   131

In [7]:
flightDS.count

2351068

In [8]:
flightDS.printSchema

root
 |-- yr: integer (nullable = false)
 |-- mon: integer (nullable = false)
 |-- dofM: integer (nullable = false)
 |-- dofW: integer (nullable = false)
 |-- date: string (nullable = true)
 |-- carrier: string (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- crsdeptime: integer (nullable = false)
 |-- deptime: string (nullable = true)
 |-- depdelay: double (nullable = false)
 |-- crsarrtime: integer (nullable = false)
 |-- arrtime: string (nullable = true)
 |-- arrdelay: double (nullable = false)
 |-- canc: double (nullable = false)
 |-- crselapsedtime: double (nullable = false)
 |-- elapsed: double (nullable = false)
 |-- air: double (nullable = false)
 |-- dist: double (nullable = false)
 |-- dummy: string (nullable = true)
 |-- crsdephour: integer (nullable = false)



In [9]:
flightDS.createOrReplaceTempView("flights")
spark.catalog.cacheTable("flights")

In [10]:
import org.apache.spark.mllib.stat.Statistics
val depdelay = flightDS.select("depdelay").map{row:Row => row.getAs[Double]("depdelay")}.rdd
val arrdelay = flightDS.select( "arrdelay").map{row:Row => row.getAs[Double]("arrdelay")}.rdd
val correlation = Statistics.corr(depdelay,arrdelay, "pearson")

depdelay = MapPartitionsRDD[30] at rdd at <console>:56
arrdelay = MapPartitionsRDD[35] at rdd at <console>:57
correlation = 0.9543348573117445


0.9543348573117445

In [11]:
flightDS.describe("dist", "crselapsedtime","depdelay", "arrdelay").show

+-------+-----------------+------------------+------------------+------------------+
|summary|             dist|    crselapsedtime|          depdelay|          arrdelay|
+-------+-----------------+------------------+------------------+------------------+
|  count|          2351068|           2351068|           2351068|           2351068|
|   mean|796.5423390561226|140.32010771275012| 8.128869943361911|3.5410962166981133|
| stddev|595.6734193002211| 72.78377685827463|42.188337576537386| 44.54053869542065|
|    min|             31.0|             -99.0|            -122.0|            -120.0|
|    max|           4983.0|             703.0|            2109.0|            2153.0|
+-------+-----------------+------------------+------------------+------------------+



In [12]:
%%sql 
select carrier,origin, dest, depdelay,crsdephour, dist, dofW
from flights 
order by depdelay desc limit 5 

+-------+--------------------+---------------...


+-------+--------------------+--------------------+--------+----------+------+----+
|carrier|              origin|                dest|depdelay|crsdephour|  dist|dofW|
+-------+--------------------+--------------------+--------+----------+------+----+
|     AA|        Hartford, CT|    Philadelphia, PA|  2109.0|        19| 196.0|   7|
|     OO|Bristol/Johnson C...|         Atlanta, GA|  2098.0|         0| 227.0|   7|
|     AA|  Raleigh/Durham, NC|Dallas/Fort Worth...|  1822.0|        18|1061.0|   6|
|     YV|        Columbus, OH|         Houston, TX|  1789.0|        17| 986.0|   4|
|     AA|          Austin, TX|    Philadelphia, PA|  1787.0|        13|1430.0|   7|
+-------+--------------------+--------------------+--------+----------+------+----+



In [13]:
val delaybucketizer = new Bucketizer().setInputCol("depdelay")
  .setOutputCol("delayed").setSplits(Array(Double.NegativeInfinity,40.0,Double.PositiveInfinity))

val flightDS4= delaybucketizer.transform(flightDS)

flightDS4.groupBy("delayed").count.show

+-------+-------+
|delayed|  count|
+-------+-------+
|    0.0|2158903|
|    1.0| 192165|
+-------+-------+



delaybucketizer = bucketizer_7a358c820f3a
flightDS4 = [yr: int, mon: int ... 20 more fields]


[yr: int, mon: int ... 20 more fields]

In [14]:
val fractions = Map(0.0 -> .3, 1.0->1.0)
val flightDS5 = flightDS4.stat.sampleBy("delayed", fractions, 36L)
flightDS5.groupBy("delayed").count.show

+-------+------+
|delayed| count|
+-------+------+
|    0.0|647139|
|    1.0|192165|
+-------+------+



fractions = Map(0.0 -> 0.3, 1.0 -> 1.0)
flightDS5 = [yr: int, mon: int ... 20 more fields]


[yr: int, mon: int ... 20 more fields]

In [15]:
// categorical Column names
val categoricalColumns = Array( "carrier", "origin", "dest", "dofW")


// String Indexers will encode string categorial columns
// into a column numeric indices
val stringIndexers = categoricalColumns.map { colName =>
      new StringIndexer()
        .setInputCol(colName)
        .setOutputCol(colName + "Indexed")
        .setHandleInvalid("keep")
}


//OneHotEncoders map number indices column to column of binary vectors
val encoders = categoricalColumns.map { colName =>
      new OneHotEncoder()
        .setInputCol(colName+"Indexed")
        .setOutputCol(colName + "Enc")
}

categoricalColumns = Array(carrier, origin, dest, dofW)
stringIndexers = Array(strIdx_2878ba18a705, strIdx_8e21cd7975ff, strIdx_f8adc026ba71, strIdx_e1d915b03ce0)
encoders = Array(oneHot_0ac4bb0bee21, oneHot_9404dd02b4bf, oneHot_c55552a9bd98, oneHot_77a698486db2)




Array(oneHot_0ac4bb0bee21, oneHot_9404dd02b4bf, oneHot_c55552a9bd98, oneHot_77a698486db2)

In [16]:
val labeler = new Bucketizer().setInputCol("depdelay")
   .setOutputCol("label")
   .setSplits(Array( Double.NegativeInfinity, 40.0, Double.PositiveInfinity))

val featureCols = Array( "carrierEnc", "destEnc", "originEnc", 
   "dofWEnc","crsdephour","crselapsedtime","crsarrtime","crsdeptime","dist")

//put features into a feature vector column   
val assembler = new VectorAssembler()
   .setInputCols(featureCols)
   .setOutputCol("features")

labeler = bucketizer_6e3a952d406a
featureCols = Array(carrierEnc, destEnc, originEnc, dofWEnc, crsdephour, crselapsedtime, crsarrtime, crsdeptime, dist)
assembler = vecAssembler_9201fd671880


vecAssembler_9201fd671880

### Set up pipeline with feature transformers

In [17]:
val steps = stringIndexers ++ encoders  ++  Array(labeler, assembler)
val feature_pipeline = new Pipeline().setStages(steps)

steps = Array(strIdx_2878ba18a705, strIdx_8e21cd7975ff, strIdx_f8adc026ba71, strIdx_e1d915b03ce0, oneHot_0ac4bb0bee21, oneHot_9404dd02b4bf, oneHot_c55552a9bd98, oneHot_77a698486db2, bucketizer_6e3a952d406a, vecAssembler_9201fd671880)


feature_pipeline: org.apache.spark.m...


Array(strIdx_2878ba18a705, strIdx_8e21cd7975ff, strIdx_f8adc026ba71, strIdx_e1d915b03ce0, oneHot_0ac4bb0bee21, oneHot_9404dd02b4bf, oneHot_c55552a9bd98, oneHot_77a698486db2, bucketizer_6e3a952d406a, vecAssembler_9201fd671880)

In [18]:
val featuresTransformer = feature_pipeline.fit(flightDS5)

featuresTransformer = pipeline_3dd211133730


pipeline_3dd211133730

In [19]:
featuresTransformer.transform(flightDS5).select("features").show

+--------------------+
|            features|
+--------------------+
|(721,[0,24,394,71...|
|(721,[0,17,394,71...|
|(721,[0,78,394,71...|
|(721,[0,47,394,71...|
|(721,[0,47,394,71...|
|(721,[0,47,394,71...|
|(721,[0,27,394,71...|
|(721,[0,27,394,71...|
|(721,[0,27,394,71...|
|(721,[0,81,394,71...|
|(721,[0,81,394,71...|
|(721,[0,39,394,71...|
|(721,[0,39,394,71...|
|(721,[0,39,394,71...|
|(721,[0,39,394,71...|
|(721,[0,34,394,71...|
|(721,[0,60,394,71...|
|(721,[0,60,394,71...|
|(721,[0,44,394,71...|
|(721,[0,84,394,71...|
+--------------------+
only showing top 20 rows



In [20]:
import com.ibm.snap.ml.{SnapLogisticRegression => LogisticRegression}
val lr = new LogisticRegression().setLabelCol("label")
  .setFeaturesCol("features")

lr = SnapMLLogisticRegression_0a01582789b0


import com.ibm.snap.ml.{SnapLogisticRegression=>LogisticRegression}


SnapMLLogisticRegression_0a01582789b0

In [21]:
val training_pipeline = new Pipeline().setStages(steps ++ Array(lr))

training_pipeline = pipeline_6e6116135751


pipeline_6e6116135751

In [22]:
val paramGrid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array( 0.001, 0.1))
    .addGrid(lr.balanced, Array(true, false))
    .addGrid(lr.maxIter, Array(10))
    .build()

val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol("label").setPredictionCol("prediction")
    .setMetricName("accuracy")   

// Set up 3-fold cross validation with paramGrid
 val crossval = new CrossValidator().setEstimator(training_pipeline)
      .setEvaluator (evaluator)
      .setEstimatorParamMaps(paramGrid).setNumFolds(2)

paramGrid = 


Array({
	SnapMLLogisticRegression_0a01582789b0-balanced: true,
	SnapMLLogisticRegression_0a01582789b0-maxIter: 10,
	SnapMLLogisticRegression_0a01582789b0-regParam: 0.001
}, {
	SnapMLLogisticRegression_0a01582789b0-balanced: false,
	SnapMLLogisticRegression_0a01582789b0-maxIter: 10,
	SnapMLLogisticRegression_0a01582789b0-regParam: 0.001
}, {
	SnapMLLogisticRegression_0a01582789b0-balanced: true,
	SnapMLLogisticRegression_0a01582789b0-maxIter: 10,
	SnapMLLogisticRegression_0a01582789b0-regParam: 0.1
}, {
	SnapMLLogisticRegression_0a01582789b0-balanced: false,
	SnapMLLogisticRegression_0a01582789b0-maxIter: 10,
	SnapMLLogisticRegression_0a01582789b0-regParam: 0.1
})
evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEval...


### Use Cross Validator Estimator to fit the training data set

In [23]:
val cvModel = crossval.fit(flightDS5)



cvModel = cv_fc7c65bd2368


cv_fc7c65bd2368

#### 10-fold Cross Validation
![cross1](cross.png)

In [24]:
val lrModel = cvModel.bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel]
            .stages.last
            .asInstanceOf[LogisticRegressionModel]

lrModel = SnapMLLogisticRegression_0a01582789b0


SnapMLLogisticRegression_0a01582789b0

In [25]:
lrModel.explainParams() 

balanced: If set to ‘False’, all classes will have weight 1. (default: false, current: false)
dual: Dual or Primal formulation. Recommendation: if n_samples > n_features use dual=True. (default: true)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: features)
gpuMemLimit: Limit of the GPU memory. If set to the default value 0, the maximum possible memory is used. (default: 0)
labelCol: label column name (default: label, current: label)
...


In [38]:
lrModel.extractParamMap()

{
	SnapMLLogisticRegression_0a01582789b0-balanced: false,
	SnapMLLogisticRegression_0a01582789b0-dual: true,
	SnapMLLogisticRegression_0a01582789b0-elasticNetParam: 0.0,
	SnapMLLogisticRegression_0a01582789b0-family: auto,
	SnapMLLogisticRegression_0a01582789b0-featuresCol: features,
	SnapMLLogisticRegression_0a01582789b0-gpuMemLimit: 0,
	SnapMLLogisticRegression_0a01582789b0-labelCol: label,
	SnapMLLogisticRegression_0a01582789b0-maxIter: 10,
	SnapMLLogisticRegression_0a01582789b0-nthreads: 1,
	SnapMLLogisticRegression_0a01582789b0-predictionCol: prediction,
	SnapMLLogisticRegression_0a01582789b0-probabilityCol: probability,
	SnapMLLogisticRegression_0a01582789b0-rawPredictionCol: rawPrediction,
	SnapMLLogisticRegression_0a01582789b0-regPara...


In [27]:
val testfilesDF = spark.read.format("csv").option("header", "true")
        .load("./test").toDF(colNames :_* )
        .na.drop(Seq("yr", "mon", "dofM", "dofW", "depdelay", "crsarrtime", "crsdeptime", "crselapsedtime", "air", "dist"))
        .withColumn("yr", toInt($"yr"))
        .withColumn("mon", toInt($"mon"))
        .withColumn("dofM", toInt($"dofM"))
        .withColumn("dofW", toInt($"dofW"))
        .withColumn("dofW", toInt($"dofW"))
        .withColumn("depdelay", toDouble($"depdelay"))
        .withColumn("crsarrtime", toInt($"crsarrtime"))
        .withColumn("crsdeptime", toInt($"crsdeptime"))
        .withColumn("arrdelay", toDouble($"arrdelay"))
        .withColumn("canc", toDouble($"canc"))
        .withColumn("crselapsedtime", toDouble($"crselapsedtime"))
        .withColumn("elapsed", toDouble($"elapsed"))
        .withColumn("air", toDouble($"air"))
        .withColumn("dist", toDouble($"dist"))
        .withColumn("crsdephour", getHour($"deptime"))

val testDS = testfilesDF.as[Flight].cache

testfilesDF = [yr: int, mon: int ... 19 more fields]
testDS = [yr: int, mon: int ... 19 more fields]


[yr: int, mon: int ... 19 more fields]

In [28]:
val testDS4 = delaybucketizer.transform(testDS)

testDS4.groupBy("delayed").count.show

+-------+------+
|delayed| count|
+-------+------+
|    0.0|530054|
|    1.0| 46196|
+-------+------+



testDS4 = [yr: int, mon: int ... 20 more fields]


[yr: int, mon: int ... 20 more fields]

In [29]:
val fractions = Map(0.0 -> .1, 1.0->1.0)
val testDS5 = testDS4.stat.sampleBy("delayed", fractions, 36L)
testDS5.groupBy("delayed").count.show

+-------+-----+
|delayed|count|
+-------+-----+
|    0.0|53027|
|    1.0|46196|
+-------+-----+



fractions = Map(0.0 -> 0.1, 1.0 -> 1.0)
testDS5 = [yr: int, mon: int ... 20 more fields]


[yr: int, mon: int ... 20 more fields]

In [30]:
val testDS6 = featuresTransformer.transform(testDS5)

testDS6 = [yr: int, mon: int ... 30 more fields]


[yr: int, mon: int ... 30 more fields]

In [31]:
testDS6.select("features").show

+--------------------+
|            features|
+--------------------+
|(721,[10,20,393,7...|
|(721,[10,34,449,7...|
|(721,[10,34,393,7...|
|(721,[10,100,380,...|
|(721,[10,58,372,7...|
|(721,[10,44,380,7...|
|(721,[10,34,413,7...|
|(721,[10,39,375,7...|
|(721,[10,42,380,7...|
|(721,[10,39,376,7...|
|(721,[10,48,376,7...|
|(721,[10,28,394,7...|
|(721,[10,77,393,7...|
|(721,[10,77,370,7...|
|(721,[10,34,447,7...|
|(721,[10,101,394,...|
|(721,[10,101,380,...|
|(721,[10,51,421,7...|
|(721,[10,47,421,7...|
|(721,[10,26,370,7...|
+--------------------+
only showing top 20 rows



### Get Predictions from Test dataset

In [32]:
val predictions = cvModel.transform(testDS5)

predictions = [yr: int, mon: int ... 33 more fields]


[yr: int, mon: int ... 33 more fields]

### Evaluate the predictions accuracy

In [33]:
val accuracy = evaluator.evaluate(predictions)

accuracy = 0.9054957016014432


0.9054957016014432

In [34]:
val lp = predictions.select("label", "prediction")
lp.show

+-----+----------+
|label|prediction|
+-----+----------+
|  0.0|       0.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  0.0|       1.0|
|  0.0|       1.0|
|  1.0|       1.0|
|  0.0|       1.0|
|  0.0|       0.0|
|  0.0|       0.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  0.0|       1.0|
|  0.0|       1.0|
|  0.0|       0.0|
|  0.0|       1.0|
|  0.0|       1.0|
|  1.0|       1.0|
|  1.0|       0.0|
|  0.0|       1.0|
+-----+----------+
only showing top 20 rows



lp = [label: double, prediction: double]


[label: double, prediction: double]

In [35]:
val counttotal = predictions.count()
val label0count  = lp.filter($"label" === 0.0).count()
val pred0count = lp.filter($"prediction" === 0.0).count()
val label1count = lp.filter($"label" === 1.0).count()
val pred1count = lp.filter($"prediction" === 1.0).count()

val correct = lp.filter($"label" === $"prediction").count()
val wrong = lp.filter(not($"label" === $"prediction")).count()
val ratioWrong = wrong.toDouble / counttotal.toDouble
val ratioCorrect = correct.toDouble / counttotal.toDouble
val truep = lp.filter($"prediction" === 0.0)
    .filter($"label" === $"prediction").count() / counttotal.toDouble
val truen = lp.filter($"prediction" === 1.0)
    .filter($"label" === $"prediction").count() / counttotal.toDouble
val falsep = lp.filter($"prediction" === 0.0)
    .filter(not($"label" === $"prediction")).count() / counttotal.toDouble
val falsen = lp.filter($"prediction" === 1.0)
    .filter(not($"label" === $"prediction")).count() / counttotal.toDouble

counttotal = 99223
label0count = 53027
pred0count = 52336
label1count = 46196
pred1count = 46887
correct = 89846
wrong = 9377
ratioWrong = 0.09450429839855679
ratioCorrect = 0.9054957016014432
truep = 0.4836882577628171
truen = 0.4218074438386261
falsep = 0.04377009362748556
falsen = 0.050734204771071226


0.050734204771071226

In [36]:
cvModel.write.overwrite().save("./FlightModel")

In [37]:
val sameCVModel = CrossValidatorModel.load("./FlightModel")

sameCVModel = cv_fc7c65bd2368


cv_fc7c65bd2368