#### Linkages Problem

From Advanced Analytics with Spark, 2nd Edition.  
*2. Introduction to Data Analysis with Scala and Spark*

In [173]:
// Read in the data. Since we're running Spark locally, read only 1 of the 10 partitions (~500k rows)
val dat = spark.read
    .option("header", "true")
    .option("nullValue", "?")
    .option("inferSchema", "true")
    .csv("./data/linkages.csv")
dat.getClass

/*
Using schema inference requires an extra pass over the data. 
If we know the schema ahead of time, define it using `org.apache.spark.sql.types.StructType`,
and pass it to the reader. 
This can improve performance significantly.
*/

dat = [id_1: int, id_2: int ... 10 more fields]


class org.apache.spark.sql.Dataset

In [174]:
// Exploratory analysis
dat.count

574913

In [81]:
dat.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)
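
The comment in the first cell suggests declaring the schema up front instead of inferring it. A minimal sketch of what that could look like, with the field names and types copied from the `printSchema()` output above. The names `linkageSchema` and `readLinkage` are illustrative and not from the book.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructField, StructType}

// Field names and types mirror the printSchema() output; nullable defaults to true.
val linkageSchema = StructType(Seq(
  StructField("id_1", IntegerType),
  StructField("id_2", IntegerType),
  StructField("cmp_fname_c1", DoubleType),
  StructField("cmp_fname_c2", DoubleType),
  StructField("cmp_lname_c1", DoubleType),
  StructField("cmp_lname_c2", DoubleType),
  StructField("cmp_sex", IntegerType),
  StructField("cmp_bd", IntegerType),
  StructField("cmp_bm", IntegerType),
  StructField("cmp_by", IntegerType),
  StructField("cmp_plz", IntegerType),
  StructField("is_match", BooleanType)
))

// Hypothetical helper: same reader options as before, but .schema(...) replaces inferSchema,
// so Spark does not need the extra pass over the file.
def readLinkage(spark: SparkSession, path: String): DataFrame =
  spark.read
    .option("header", "true")
    .option("nullValue", "?")
    .schema(linkageSchema)
    .csv(path)
```

In the notebook this would be called as `readLinkage(spark, "./data/linkages.csv")`.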



In [34]:
dat.take(5)

0,1,2,3,4,5,6,7,8,9,10,11
37291,53113,0.833333333333333,,1.0,,1,1,1,1,0,True
39086,47614,1.0,,1.0,,1,1,1,1,1,True
70031,70237,1.0,,1.0,,1,1,1,1,1,True
84795,97439,1.0,,1.0,,1,1,1,1,1,True
36950,42116,1.0,,1.0,1.0,1,1,1,1,1,True


Each action above, such as `count` and `take`, made Spark reread and reparse the CSV file from scratch. To avoid this, we `.cache()` the DataFrame.

In [38]:
dat.cache()
dat.count
dat.take(2)

0,1,2,3,4,5,6,7,8,9,10,11
37291,53113,0.833333333333333,,1.0,,1,1,1,1,0,True
39086,47614,1.0,,1.0,,1,1,1,1,1,True


#### Notes on Caching

> Spark defines a few different mechanisms, or StorageLevel values, for persisting data. cache() is shorthand for persist(StorageLevel.MEMORY), which stores the rows as unserialized Java objects. When Spark estimates that a partition will not fit in memory, it simply will not store it, and it will be recomputed the next time it’s needed. This level makes the most sense when the objects will be referenced frequently and/or require low-latency access, because it avoids any serialization overhead. Its drawback is that it takes up larger amounts of memory than its alternatives. Also, holding on to many small objects puts pressure on Java’s garbage collection, which can result in stalls and general slowness.

> Spark also exposes a MEMORY_SER storage level, which allocates large byte buffers in memory and serializes the records into them. When we use the right format (more on this in a bit), serialized data usually takes up two to five times less space than its raw equivalent.

> Spark can use disk for caching data as well. The MEMORY_AND_DISK and MEMORY_AND_DISK_SER are similar to the MEMORY and MEMORY_SER storage levels, respectively. For the latter two, if a partition will not fit in memory, it is simply not stored, meaning that it must be recomputed from its dependencies the next time an action uses it. For the former, Spark spills partitions that will not fit in memory to disk.

> Although both DataFrames and RDDs can be cached, Spark can use the detailed knowledge of the data stored with a data frame available via the DataFrame’s schema to persist the data far more efficiently than it can with Java objects stored inside of RDDs.

> Deciding when to cache data can be an art. The decision typically involves trade-offs between space and speed, with the specter of garbage-collecting looming overhead to occasionally confound things further. In general, data should be cached when it is likely to be referenced by multiple actions, is relatively small compared to the amount of memory/disk available on the cluster, and is expensive to regenerate.
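
To make the storage levels concrete, here is a minimal sketch of requesting one explicitly. Note that the Spark API spells the levels quoted above as `StorageLevel.MEMORY_ONLY`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK` and `MEMORY_AND_DISK_SER`, and that for a `Dataset`/`DataFrame` the default level used by `cache()` is `MEMORY_AND_DISK`. The three-row `sample` DataFrame is a made-up stand-in for `dat` so that the snippet runs on its own.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// In a notebook `spark` already exists; getOrCreate() then returns that session.
val spark = SparkSession.builder().master("local[*]").appName("cache-sketch").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for `dat`.
val sample = Seq((1, true), (2, false), (3, false)).toDF("id_1", "is_match")

// Ask for a serialized level that spills to disk instead of relying on the cache() default.
sample.persist(StorageLevel.MEMORY_AND_DISK_SER)

// persist() is lazy: the first action materializes the cached data.
val n = sample.count()
val levelWhileCached = sample.storageLevel

// Release the cached blocks once the data is no longer needed.
sample.unpersist()
```

Calling `persist` with a different level on an already cached Dataset does not change its level, so `unpersist()` first if the level needs to change.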

In [45]:
// How many records are matches?

// Using the RDD API
val datRdd = dat.rdd
datRdd.map((x: org.apache.spark.sql.Row) => x.getAs[Boolean]("is_match")).countByValue()

datRdd = MapPartitionsRDD[155] at rdd at <console>:32


Map(false -> 572820, true -> 2093)

In [47]:
// Using the DataFrame API
dat.groupBy("is_match")
    .count()
    .orderBy($"count".desc)
    .show()

+--------+------+                                                               
|is_match| count|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



#### Column references
There are two ways to reference columns:  

`.groupBy("is_match")` - string literal  
`.groupBy($"is_match")` - column object

Both forms work in `groupBy`, but `orderBy` needs the `Column` form here because `.desc` is a method defined on `Column` objects, not on strings.
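
A small sketch of the two notations side by side, using a made-up three-row DataFrame named `toy` rather than `dat`. `col("...")` from `org.apache.spark.sql.functions` is a third way to build the same `Column` object as `$"..."`.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// In a notebook `spark` already exists; getOrCreate() then returns that session.
val spark = SparkSession.builder().master("local[*]").appName("column-refs").getOrCreate()
import spark.implicits._ // provides toDF and the $"..." syntax

// Hypothetical toy data: one match, two non-matches.
val toy = Seq(true, false, false).toDF("is_match")

// groupBy accepts either a column name or a Column object.
val byName = toy.groupBy("is_match").count()
val byColumn = toy.groupBy($"is_match").count()

// Sorting in descending order goes through Column.desc.
val sorted = toy.groupBy(col("is_match")).count().orderBy(col("count").desc)
val counts = sorted.collect().map(r => (r.getBoolean(0), r.getLong(1))).toList
```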

In [63]:
// multiple aggregation functions
import org.apache.spark.sql.functions._

dat.agg(avg($"cmp_fname_c1"),
        stddev($"cmp_fname_c1"),
        max($"cmp_fname_c1"),
        min($"cmp_fname_c1"))
    .show()

+------------------+-------------------------+-----------------+-----------------+
| avg(cmp_fname_c1)|stddev_samp(cmp_fname_c1)|max(cmp_fname_c1)|min(cmp_fname_c1)|
+------------------+-------------------------+-----------------+-----------------+
|0.7127592938253411|       0.3889286452463531|              1.0|              0.0|
+------------------+-------------------------+-----------------+-----------------+



In [79]:
// Register the DataFrame as a temporary view so it can be queried with SQL
// The view is TEMPORARY: it exists only for the lifetime of this Spark session
dat.createOrReplaceTempView("dataset")

spark.sql("""
    SELECT *
    FROM dataset
    WHERE is_match = true AND cmp_plz = 1
    LIMIT 5
""").show()

+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
|39086|47614|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|70031|70237|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|84795|97439|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|36950|42116|         1.0|        null|         1.0|         1.0|      1|     1|     1|     1|      1|    true|
|42413|48491|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
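
The same filter can also be expressed with the DataFrame API. A minimal sketch comparing the two on a made-up three-row DataFrame; the `toy` data and the `toy_dataset` view name are illustrative only.

```scala
import org.apache.spark.sql.SparkSession

// In a notebook `spark` already exists; getOrCreate() then returns that session.
val spark = SparkSession.builder().master("local[*]").appName("sql-vs-api").getOrCreate()
import spark.implicits._

// Hypothetical toy data with only the columns the query touches.
val toy = Seq((1, 1, true), (2, 0, true), (3, 1, false)).toDF("id_1", "cmp_plz", "is_match")
toy.createOrReplaceTempView("toy_dataset")

// SQL against the temporary view ...
val viaSql = spark.sql("SELECT * FROM toy_dataset WHERE is_match = true AND cmp_plz = 1")

// ... and the equivalent filter built from Column expressions.
val viaApi = toy.filter($"is_match" === true && $"cmp_plz" === 1)

val sqlRows = viaSql.collect().toList
val apiRows = viaApi.collect().toList
```

Both forms are parsed into the same kind of logical plan, so the choice is mostly a matter of readability.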

In [88]:
// Summary statistics
val summary = dat.describe()
summary.select("summary", "id_1", "cmp_fname_c1", "cmp_fname_c2").show()

// note that cmp_fname_c2 has many more nulls than cmp_fname_c1

+-------+------------------+------------------+------------------+              
|summary|              id_1|      cmp_fname_c1|      cmp_fname_c2|
+-------+------------------+------------------+------------------+
|  count|            574913|            574811|             10325|
|   mean|33271.962171667714|0.7127592938253411|0.8977586763518969|
| stddev|23622.669425933756|0.3889286452463531|0.2742577520430532|
|    min|                 1|               0.0|               0.0|
|    max|             99894|               1.0|               1.0|
+-------+------------------+------------------+------------------+



summary = [summary: string, id_1: string ... 10 more fields]


[summary: string, id_1: string ... 10 more fields]

In [100]:
// Summary statistics for the matching and non-matching subsets

val matches = dat.filter($"is_match" === true)  // condition as a Column expression
val matchesSummary = matches.describe()

val misses = dat.where("is_match = false")  // same kind of filter, as a SQL expression string
val missesSummary = misses.describe()



matches = [id_1: int, id_2: int ... 10 more fields]
matchesSummary = [summary: string, id_1: string ... 10 more fields]
misses = [id_1: int, id_2: int ... 10 more fields]
missesSummary = [summary: string, id_1: string ... 10 more fields]


[summary: string, id_1: string ... 10 more fields]

In [113]:
// The summary DataFrame from describe() is in wide format: one row per metric, one column per field
summary.select("summary", "cmp_plz", "cmp_fname_c1").show()

+-------+--------------------+------------------+
|summary|             cmp_plz|      cmp_fname_c1|
+-------+--------------------+------------------+
|  count|              573618|            574811|
|   mean|0.005494946113964...|0.7127592938253411|
| stddev| 0.07392402321301972|0.3889286452463531|
|    min|                   0|               0.0|
|    max|                   1|               1.0|
+-------+--------------------+------------------+



In [125]:
// Reshape the summary to long format: one (metric, field, value) row per statistic and field
val schema = summary.schema
val summaryLong = summary.flatMap(row => {
    val metric = row.getString(0)
    (1 until row.size).map(i => (metric, schema(i).name, row.getString(i).toDouble))
})

summaryLong.filter($"_2" === "cmp_plz").show()

+------+-------+--------------------+
|    _1|     _2|                  _3|
+------+-------+--------------------+
| count|cmp_plz|            573618.0|
|  mean|cmp_plz|0.005494946113964...|
|stddev|cmp_plz| 0.07392402321301972|
|   min|cmp_plz|                 0.0|
|   max|cmp_plz|                 1.0|
+------+-------+--------------------+



schema = StructType(StructField(summary,StringType,true), StructField(id_1,StringType,true), StructField(id_2,StringType,true), StructField(cmp_fname_c1,StringType,true), StructField(cmp_fname_c2,StringType,true), StructField(cmp_lname_c1,StringType,true), StructField(cmp_lname_c2,StringType,true), StructField(cmp_sex,StringType,true), StructField(cmp_bd,StringType,true), StructField(cmp_bm,StringType,true), StructField(cmp_by,StringType,true), StructField(cmp_plz,StringType,true))
summaryLong = [_1: string, _2: string ... 1 more field]


[_1: string, _2: string ... 1 more field]

#### Notes on Implicit Type Conversion (String.toDouble)

> The toDouble method is an example of one of Scala’s most powerful (and arguably dangerous) features: implicit types. In Scala, an instance of the String class is just a java.lang.String, and the java.lang.String class does not have a method named toDouble. Instead, the methods are defined in a Scala class called StringOps. Implicits work like this: if you call a method on a Scala object, and the Scala compiler does not see a definition for that method in the class definition for that object, the compiler will try to convert your object to an instance of a class that does have that method defined. In this case, the compiler will see that Java’s String class does not have a toDouble method defined but that the StringOps class does, and that the StringOps class has a method that can convert an instance of the String class into an instance of the StringOps class. The compiler silently performs the conversion of our String object into a StringOps object, and then calls the toDouble method on the new object.
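
The mechanism described above can be tried in plain Scala without Spark. In this sketch, `Predef.augmentString` is the standard-library conversion from `String` to `StringOps`; `PercentSyntax`, `PercentOps` and `toPercent` are hypothetical names invented here to show how the same trick adds a new method to `String`.

```scala
object PercentSyntax {
  // Hypothetical enrichment: gives String a toPercent method,
  // the same way StringOps gives it toDouble.
  implicit class PercentOps(s: String) {
    def toPercent: Double = s.stripSuffix("%").trim.toDouble / 100
  }
}
import PercentSyntax._

// The compiler finds no toDouble on java.lang.String and inserts the conversion itself.
val viaImplicit = "0.5".toDouble

// The same call with the conversion written out by hand.
val explicitly = Predef.augmentString("0.5").toDouble

// Our own implicit class is applied in exactly the same way.
val quarter = "25%".toPercent
```

Keeping the implicit class inside an object and importing it explicitly limits where the conversion applies, which is one way to contain the "arguably dangerous" side of the feature.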

In [129]:
summaryLong.getClass

class org.apache.spark.sql.Dataset

#### The Dataset[T] Interface

Calling `.flatMap()` on a `DataFrame` returns a `Dataset[T]`, here a `Dataset[(String, String, Double)]`. Since Spark 2.0, `DataFrame` is just a type alias for `Dataset[Row]` in the Scala API; `Dataset[T]` generalizes it to element types other than `Row`.
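
As an illustration of a typed `Dataset[T]`, the sketch below uses a case class in place of the anonymous tuple. `SummaryStat` is a hypothetical type introduced here; the two rows reuse the `cmp_plz` values shown in the output above.

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Hypothetical record type mirroring the (metric, field, value) tuples.
case class SummaryStat(metric: String, field: String, value: Double)

// In a notebook `spark` already exists; getOrCreate() then returns that session.
val spark = SparkSession.builder().master("local[*]").appName("dataset-sketch").getOrCreate()
import spark.implicits._ // supplies the encoders for case classes and primitives

val stats: Dataset[SummaryStat] = Seq(
  SummaryStat("count", "cmp_plz", 573618.0),
  SummaryStat("max", "cmp_plz", 1.0)
).toDS()

// Typed transformations see real fields instead of Row positions.
val large = stats.filter(_.value > 1.0).collect().toList
val fieldNames = stats.map(_.field).collect().toSet

// Going back to an untyped DataFrame keeps the case class field names as columns.
val asDataFrame: DataFrame = stats.toDF()
```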

In [139]:
// We can convert Dataset[T] back into a DataFrame with .toDF(columns)
val summaryDFLong = summaryLong.toDF("metric", "field", "value")
val summaryDFLongShort = summaryDFLong.filter($"field" === "cmp_plz")
summaryDFLongShort.show()

+------+-------+--------------------+
|metric|  field|               value|
+------+-------+--------------------+
| count|cmp_plz|            573618.0|
|  mean|cmp_plz|0.005494946113964...|
|stddev|cmp_plz| 0.07392402321301972|
|   min|cmp_plz|                 0.0|
|   max|cmp_plz|                 1.0|
+------+-------+--------------------+



summaryDFLong = [metric: string, field: string ... 1 more field]
summaryDFLongShort = [metric: string, field: string ... 1 more field]


[metric: string, field: string ... 1 more field]

In [148]:
// Reshaping from long back to wide
val summaryDFWideShort = summaryDFLongShort
    .groupBy("field")
    .pivot("metric", Array("count", "mean", "stddev", "min", "max"))
    .agg(first("value"))
summaryDFWideShort.show()

+-------+--------+--------------------+-------------------+---+---+
|  field|   count|                mean|             stddev|min|max|
+-------+--------+--------------------+-------------------+---+---+
|cmp_plz|573618.0|0.005494946113964...|0.07392402321301972|0.0|1.0|
+-------+--------+--------------------+-------------------+---+---+



summaryDFWideShort = [field: string, count: double ... 4 more fields]


[field: string, count: double ... 4 more fields]

In [151]:
summary.select("summary", "cmp_plz").show()

+-------+--------------------+
|summary|             cmp_plz|
+-------+--------------------+
|  count|              573618|
|   mean|0.005494946113964...|
| stddev| 0.07392402321301972|
|    min|                   0|
|    max|                   1|
+-------+--------------------+



In [219]:
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

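// Wide -> long: turn a describe()-style summary (one column per field)
// into one (metric, field, value) row per statistic and field.
def gather(data: DataFrame): DataFrame = {
    val schema = data.schema
    val pivoted = data.flatMap(row => {
        val statistic = row.getString(0)      // first column holds the statistic name
        val columnIndices = 1 until row.size  // remaining columns hold one value per field
        columnIndices.map(i => (statistic, schema(i).name, row.getString(i).toDouble))
    })
    pivoted.toDF("metric", "field", "value")
}

// Long -> wide: one row per field, one column per metric.
def spread(data: DataFrame): DataFrame = {
    require(data.columns.contains("metric"), "A column 'metric' must be in the DataFrame")
    
    // Listing the pivot values up front saves Spark a pass to discover them
    val metricColumns = Array("count", "max", "mean", "stddev", "min")
    
    data.groupBy("field")
        .pivot("metric", metricColumns)
        .agg(first("value"))
}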

gather: (data: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame
spread: (data: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame


+-------+--------+---+--------------------+-------------------+---+
|  field|   count|max|                mean|             stddev|min|
+-------+--------+---+--------------------+-------------------+---+
|cmp_plz|573618.0|1.0|0.005494946113964...|0.07392402321301972|0.0|
+-------+--------+---+--------------------+-------------------+---+



In [244]:
gather(summary.select("summary", "cmp_sex", "cmp_plz")).show(5)

+------+-------+--------------------+
|metric|  field|               value|
+------+-------+--------------------+
| count|cmp_sex|            574913.0|
| count|cmp_plz|            573618.0|
|  mean|cmp_sex|  0.9550923357099248|
|  mean|cmp_plz|0.005494946113964...|
|stddev|cmp_sex| 0.20710152240504442|
+------+-------+--------------------+
only showing top 5 rows



In [245]:
val summaryLongDF = summaryLong
    .toDF("metric", "field", "value")
    .filter($"field" === "cmp_sex" || $"field" === "cmp_plz")
spread(summaryLongDF).show(5)

+-------+--------+---+--------------------+-------------------+---+
|  field|   count|max|                mean|             stddev|min|
+-------+--------+---+--------------------+-------------------+---+
|cmp_plz|573618.0|1.0|0.005494946113964...|0.07392402321301972|0.0|
|cmp_sex|574913.0|1.0|  0.9550923357099248|0.20710152240504442|0.0|
+-------+--------+---+--------------------+-------------------+---+



summaryLongDF = [metric: string, field: string ... 1 more field]


[metric: string, field: string ... 1 more field]